Commit 152f83c

refac: add ESNCell built on LuxCore

1 parent cb409df commit 152f83c

File tree

7 files changed: +394 −25 lines changed

Project.toml

Lines changed: 8 additions & 0 deletions

@@ -5,11 +5,15 @@ version = "0.11.4"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
+ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
+Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
 WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
 
 [weakdeps]
@@ -27,17 +31,21 @@ RCSparseArraysExt = "SparseArrays"
 [compat]
 Adapt = "4.1.1"
 Aqua = "0.8"
+ArrayInterface = "7.19.0"
 CellularAutomata = "0.0.6"
 Compat = "4.16.0"
+ConcreteStructs = "0.2.3"
 DifferentialEquations = "7.16.1"
 LIBSVM = "0.8"
 LinearAlgebra = "1.10"
+LuxCore = "1.3.0"
 MLJLinearModels = "0.9.2, 0.10"
 NNlib = "0.9.26"
 Random = "1.10"
 Reexport = "1.2.2"
 SafeTestsets = "0.1"
 SparseArrays = "1.10"
+Static = "1.2.0"
 Statistics = "1.10"
 Test = "1"
 WeightInitializers = "1.0.5"

src/ReservoirComputing.jl

Lines changed: 24 additions & 5 deletions

@@ -1,18 +1,33 @@
 module ReservoirComputing
 
 using Adapt: adapt
+using ArrayInterface: ArrayInterface
 using Compat: @compat
+using ConcreteStructs: @concrete
 using LinearAlgebra: eigvals, mul!, I, qr, Diagonal
+using LuxCore: AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer,
+        setup, apply, replicate
+import LuxCore: initialparameters, initialstates, statelength, outputsize
 using NNlib: fast_act, sigmoid
 using Random: Random, AbstractRNG, randperm
+using Static: StaticBool, StaticInt, StaticSymbol,
+        True, False, static, known, dynamic, StaticInteger
 using Reexport: Reexport, @reexport
 using WeightInitializers: DeviceAgnostic, PartialFunction, Utils
 @reexport using WeightInitializers
+@reexport using LuxCore: setup, apply
 
 abstract type AbstractReservoirComputer end
 
+const BoolType = Union{StaticBool,Bool,Val{true},Val{false}}
+const InputType = Tuple{<:AbstractArray,Tuple{<:AbstractArray}}
+const IntegerType = Union{Integer,StaticInteger}
+
 @compat(public, (create_states))
 
+#layers
+include("layers/lux_layers.jl")
+include("layers/esn_cell.jl")
 #general
 include("generics/states.jl")
 include("generics/predict.jl")
@@ -28,17 +43,21 @@ include("models/deepesn.jl")
 include("models/hybridesn.jl")
 include("models/esn_predict.jl")
 
+
+
+export ESNCell, StatefulLayer, Readout, ReservoirChain, Collect, collectstates, train!, predict
+
 export NLADefault, NLAT1, NLAT2, NLAT3, PartialSquare, ExtendedSquare
 export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates
 export StandardRidge
 export chebyshev_mapping, informed_init, logistic_mapping, minimal_init,
-       modified_lm, scaled_rand, weighted_init, weighted_minimal
+    modified_lm, scaled_rand, weighted_init, weighted_minimal
 export block_diagonal, chaotic_init, cycle_jumps, delay_line, delay_line_backward,
-       double_cycle, forward_connection, low_connectivity, pseudo_svd, rand_sparse,
-       selfloop_cycle, selfloop_delayline_backward, selfloop_feedback_cycle,
-       selfloop_forward_connection, simple_cycle, true_double_cycle
+    double_cycle, forward_connection, low_connectivity, pseudo_svd, rand_sparse,
+    selfloop_cycle, selfloop_delayline_backward, selfloop_feedback_cycle,
+    selfloop_forward_connection, simple_cycle, true_double_cycle
 export add_jumps!, backward_connection!, delay_line!, reverse_simple_cycle!,
-       scale_radius!, self_loop!, simple_cycle!
+    scale_radius!, self_loop!, simple_cycle!
 export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal
 export train
 export ESN, HybridESN, KnowledgeModel, DeepESN
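
Taken together, the new exports outline a Lux-style workflow: compose layers into a chain, `setup` parameters and states, fit the readout with `train!`, then roll the model forward with `predict`. A minimal sketch of the intended composition, assuming `ReservoirChain`, `StatefulLayer`, and `Readout` (defined in `layers/lux_layers.jl`, which is not shown in this diff) behave like their Lux counterparts; the constructor signatures, layer sizes, and data below are placeholders, not confirmed API:

```julia
using ReservoirComputing, Random

# Toy series: 3 features × 500 steps (placeholder data).
input_data = rand(Float32, 3, 500)
target_data = rand(Float32, 3, 500)

# Assumed composition: stateful ESN cell feeding a trainable readout.
rc = ReservoirChain(StatefulLayer(ESNCell(3 => 300)), Readout(300 => 3))
ps, st = setup(Random.default_rng(), rc)   # setup/apply are reexported from LuxCore

# Ridge-fit the readout in place (see linear_regression.jl below).
ps, st = train!(rc, input_data, target_data, ps, st, StandardRidge(1e-6))
```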

src/generics/linear_regression.jl

Lines changed: 31 additions & 3 deletions

@@ -6,7 +6,7 @@ Returns a training method for `train` based on ridge regression.
 The equations for ridge regression are as follows:
 
 ```math
-\mathbf{w} = (\mathbf{X}^\top \mathbf{X} + 
+\mathbf{w} = (\mathbf{X}^\top \mathbf{X} +
 \lambda \mathbf{I})^{-1} \mathbf{X}^\top \mathbf{y}
 ```
 
@@ -21,20 +21,48 @@ struct StandardRidge
     reg::Number
 end
 
-function StandardRidge(::Type{T}, reg) where {T <: Number}
+function StandardRidge(::Type{T}, reg) where {T<:Number}
     return StandardRidge(T.(reg))
 end
 
 function StandardRidge()
     return StandardRidge(0.0)
 end
 
+function train!(rc::ReservoirChain, train_data::AbstractArray,
+        target_data::AbstractArray, ps, st::NamedTuple, sr::StandardRidge=StandardRidge(0.0);
+        return_states::Bool=false)
+    states = collectstates(rc, train_data, ps, st)
+    readout = train(sr, states, target_data)
+    ps, st = addreadout!(rc, readout, ps, st)
+
+    if return_states
+        return (ps, st), states
+    else
+        return ps, st
+    end
+end
+
 function train(sr::StandardRidge, states::AbstractArray, target_data::AbstractArray)
     n_states = size(states, 1)
     A = [states'; sqrt(sr.reg) * I(n_states)]
     b = [target_data'; zeros(n_states, size(target_data, 1))]
     F = qr(A)
    Wt = F \ b
     output_layer = Matrix(Wt')
-    return OutputLayer(sr, output_layer, size(target_data, 1), target_data[:, end])
+    return output_layer
 end
+
+function addreadout!(rc::ReservoirChain, readout_matrix::AbstractArray, ps, st::NamedTuple) # make sure the compiler infers
+    ro_param = (; weight=readout_matrix)
+    new_ps = (;)
+    for ((name, layer), param) in zip(pairs(rc.layers), ps)
+        if layer isa Readout
+            param = merge(param, ro_param)
+        end
+        new_ps = merge(new_ps, (; name => param))
+    end
+    return new_ps, st
+end
+
+#use a recursion to make it more compiler safe
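
`train` now returns the bare readout matrix instead of an `OutputLayer`, and it solves the ridge problem as a stacked least-squares system via QR rather than through the normal equations, which is better conditioned. A standalone check (not part of the commit) that the stacked system reproduces the closed-form solution from the docstring, with states stored features × samples:

```julia
using LinearAlgebra

X = randn(50, 200)    # reservoir states: 50 features × 200 time steps
Y = randn(3, 200)     # targets: 3 outputs × 200 time steps
reg = 1e-3            # ridge penalty λ

# Stacked system, as in `train`: appending √λ·I rows to X' turns the
# least-squares solution of [X'; √λ·I] w = [Y'; 0] into the ridge minimizer.
A = [X'; sqrt(reg) * I(50)]
b = [Y'; zeros(50, 3)]
W_qr = Matrix((qr(A) \ b)')

# Closed form from the docstring, transposed for this storage convention:
# W = ((X Xᵀ + λI)⁻¹ X Yᵀ)ᵀ
W_closed = Matrix(((X * X' + reg * I) \ (X * Y'))')
@assert W_qr ≈ W_closed
```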

src/generics/predict.jl

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ abstract type AbstractOutputLayer end
22
abstract type AbstractPrediction end
33

44
#general output layer struct
5-
struct OutputLayer{T, I, S, L} <: AbstractOutputLayer
5+
struct OutputLayer{T,I,S,L} <: AbstractOutputLayer
66
training_method::T
77
output_matrix::I
88
out_size::S
@@ -39,7 +39,7 @@ struct Generative{T} <: AbstractPrediction
3939
prediction_len::T
4040
end
4141

42-
struct Predictive{I, T} <: AbstractPrediction
42+
struct Predictive{I,T} <: AbstractPrediction
4343
prediction_data::I
4444
prediction_len::T
4545
end
@@ -67,8 +67,8 @@ function Predictive(prediction_data::AbstractArray)
6767
end
6868

6969
function obtain_prediction(rc::AbstractReservoirComputer, prediction::Generative,
70-
x, output_layer::AbstractOutputLayer, args...;
71-
initial_conditions = output_layer.last_value)
70+
x, output_layer::AbstractOutputLayer, args...;
71+
initial_conditions=output_layer.last_value)
7272
#x = last_state
7373
prediction_len = prediction.prediction_len
7474
train_method = output_layer.training_method
@@ -86,7 +86,7 @@ function obtain_prediction(rc::AbstractReservoirComputer, prediction::Generative
8686
end
8787

8888
function obtain_prediction(rc::AbstractReservoirComputer, prediction::Predictive,
89-
x, output_layer::AbstractOutputLayer, args...; kwargs...)
89+
x, output_layer::AbstractOutputLayer, args...; kwargs...)
9090
prediction_len = prediction.prediction_len
9191
train_method = output_layer.training_method
9292
out_size = output_layer.out_size
@@ -117,3 +117,15 @@ function store_results!(training_method, out, output, i)
117117
output[:, i] = out
118118
return out
119119
end
120+
121+
function predict(rc, steps::Int, ps, st; initialdata=nothing)
122+
if initialdata == nothing
123+
initialdata = rand(Float32, 3)
124+
end
125+
output = zeros(size(initialdata, 1), steps)
126+
for step in 1:steps
127+
initialdata, st = apply(rc, initialdata, ps, st)
128+
output[:, step] = initialdata
129+
end
130+
return output, st
131+
end
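
The new `predict` runs the chain closed-loop: each step's output is fed straight back in as the next input, playing the role of `obtain_prediction`'s `Generative` mode but expressed against the Lux-style `apply`. Note that the `rand(Float32, 3)` fallback hard-codes a 3-dimensional input, so passing `initialdata` explicitly is the safer call. A hedged usage sketch, continuing the assumed `rc`, `ps`, `st` from the workflow sketch above:

```julia
# Free-running 200-step forecast, seeded with the last training column
# rather than the random 3-vector fallback.
preds, st = predict(rc, 200, ps, st; initialdata=input_data[:, end])
size(preds)   # (3, 200): one column per forecast step
```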

src/generics/states.jl

Lines changed: 12 additions & 12 deletions

@@ -1,6 +1,6 @@
-abstract type AbstractStates end
+abstract type AbstractStates <: Function end
 abstract type AbstractPaddedStates <: AbstractStates end
-abstract type NonLinearAlgorithm end
+abstract type NonLinearAlgorithm <: Function end
 
 function pad_state!(states_type::AbstractPaddedStates, x_pad, x)
     x_pad[1, :] .= states_type.padding
@@ -60,7 +60,7 @@ julia> new_mat = states(test_mat)
 struct StandardStates <: AbstractStates end
 
 function (::StandardStates)(nla_type::NonLinearAlgorithm,
-    state, inp)
+        state, inp)
     return nla(nla_type, state)
 end
 
@@ -137,7 +137,7 @@ function (::ExtendedStates)(vect::AbstractVector, inp::AbstractVector)
 end
 
 function (states_type::ExtendedStates)(nla_type::NonLinearAlgorithm,
-    state::AbstractVecOrMat, inp::AbstractVecOrMat)
+        state::AbstractVecOrMat, inp::AbstractVecOrMat)
     return nla(nla_type, states_type(state, inp))
 end
 
@@ -194,7 +194,7 @@ struct PaddedStates{T} <: AbstractPaddedStates
     padding::T
 end
 
-function PaddedStates(; padding = 1.0)
+function PaddedStates(; padding=1.0)
     return PaddedStates(padding)
 end
 
@@ -209,7 +209,7 @@ function (states_type::PaddedStates)(vect::AbstractVector)
 end
 
 function (states_type::PaddedStates)(nla_type::NonLinearAlgorithm,
-    state::AbstractVecOrMat, inp::AbstractVecOrMat)
+        state::AbstractVecOrMat, inp::AbstractVecOrMat)
     return nla(nla_type, states_type(state))
 end
 
@@ -272,17 +272,17 @@ struct PaddedExtendedStates{T} <: AbstractPaddedStates
     padding::T
 end
 
-function PaddedExtendedStates(; padding = 1.0)
+function PaddedExtendedStates(; padding=1.0)
     return PaddedExtendedStates(padding)
 end
 
 function (states_type::PaddedExtendedStates)(nla_type::NonLinearAlgorithm,
-    state::AbstractVecOrMat, inp::AbstractVecOrMat)
+        state::AbstractVecOrMat, inp::AbstractVecOrMat)
     return nla(nla_type, states_type(state, inp))
 end
 
 function (states_type::PaddedExtendedStates)(state::AbstractVecOrMat,
-    inp::AbstractVecOrMat)
+        inp::AbstractVecOrMat)
     x_pad = PaddedStates(states_type.padding)(state)
     x_ext = ExtendedStates()(x_pad, inp)
     return x_ext
@@ -539,7 +539,7 @@ function (::NLAT2)(x_old::AbstractVector)
 
     for idx in eachindex(x_old)
         if firstindex(x_old) < idx < lastindex(x_old) && isodd(idx)
-            x_new[idx, :] .= x_old[idx - 1, :] .* x_old[idx - 2, :]
+            x_new[idx, :] .= x_old[idx-1, :] .* x_old[idx-2, :]
         end
     end
 
@@ -628,7 +628,7 @@ function (::NLAT3)(x_old::AbstractVector)
 
     for idx in eachindex(x_old)
         if firstindex(x_old) < idx < lastindex(x_old) && isodd(idx)
-            x_new[idx] = x_old[idx - 1] * x_old[idx + 1]
+            x_new[idx] = x_old[idx-1] * x_old[idx+1]
         end
     end
 
@@ -645,7 +645,7 @@ Implement a partial squaring of the states as described in [Barbosa2021](@cite).
 ```math
 \begin{equation}
     g(r_i) =
-    \begin{cases} 
+    \begin{cases}
     r_i^2, & \text{if } i \leq \eta_r N, \\
     r_i, & \text{if } i > \eta_r N.
     \end{cases}
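
Past the formatting churn, the substantive change in this file is that `AbstractStates` and `NonLinearAlgorithm` now subtype `Function`, presumably so that state transforms can sit inside a `ReservoirChain` like any other callable. Their behavior is unchanged; for example, `NLAT3` still replaces interior odd-indexed entries with the product of their two neighbors (assuming the zero-field `NLAT3()` constructor, as exported):

```julia
using ReservoirComputing

x = Float32.(1:7)
NLAT3()(x)
# Interior odd indices 3 and 5 become x[2]*x[4] and x[4]*x[6]:
# Float32[1.0, 2.0, 8.0, 4.0, 24.0, 6.0, 7.0]
```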

src/layers/esn_cell.jl

Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+@concrete struct ESNCell <: AbstractReservoirRecurrentCell
+    activation
+    in_dims <: IntegerType
+    out_dims <: IntegerType
+    init_bias
+    init_reservoir
+    init_input
+    #init_feedback::F
+    init_state
+    leak_coefficient
+    use_bias <: StaticBool
+end
+
+function ESNCell((in_dims, out_dims)::Pair{<:Int,<:Int}, activation=tanh;
+        use_bias::BoolType=False(), init_bias=zeros32, init_reservoir=rand_sparse,
+        init_input=weighted_init, init_state=randn32, leak_coefficient=1.0)
+    return ESNCell(activation, in_dims, out_dims, init_bias, init_reservoir,
+        init_input, init_state, leak_coefficient, use_bias)
+end
+
+function initialparameters(rng::AbstractRNG, esn::ESNCell)
+    ps = (input_matrix=esn.init_input(rng, esn.out_dims, esn.in_dims),
+        reservoir_matrix=esn.init_reservoir(rng, esn.out_dims, esn.out_dims))
+    if has_bias(esn)
+        ps = merge(ps, (bias=esn.init_bias(rng, esn.out_dims),))
+    end
+    return ps
+end
+
+function initialstates(rng::AbstractRNG, esn::ESNCell)
+    return (rng=sample_replicate(rng),)
+end
+
+function (esn::ESNCell)(inp::AbstractArray, ps, st::NamedTuple)
+    rng = replicate(st.rng)
+    hidden_state = init_hidden_state(rng, esn, inp)
+    return esn((inp, (hidden_state,)), ps, merge(st, (; rng)))
+end
+
+function (esn::ESNCell)((inp, (hidden_state,))::InputType, ps, st::NamedTuple)
+    T = eltype(inp)
+    if has_bias(esn)
+        candidate_h = esn.activation.(ps.input_matrix * inp .+
+            ps.reservoir_matrix * hidden_state .+ ps.bias)
+    else
+        candidate_h = esn.activation.(ps.input_matrix * inp .+
+            ps.reservoir_matrix * hidden_state)
+    end
+    h_new = (T(1.0) - esn.leak_coefficient) .* hidden_state .+
+        esn.leak_coefficient .* candidate_h
+    return (h_new, (h_new,)), st
+end
+
+function Base.show(io::IO, esn::ESNCell)
+    print(io, "ESNCell($(esn.in_dims) => $(esn.out_dims)")
+    if esn.leak_coefficient != eltype(esn.leak_coefficient)(1.0)
+        print(io, ", leak_coefficient=$(esn.leak_coefficient)")
+    end
+    has_bias(esn) || print(io, ", use_bias=false")
+    print(io, ")")
+end
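
The two-argument forward pass implements the standard leaky-integrator ESN update, with leak rate α (`leak_coefficient`), input weights W_in (`input_matrix`), reservoir weights W (`reservoir_matrix`), and activation f:

```math
\mathbf{h}_{t+1} = (1 - \alpha)\,\mathbf{h}_t
    + \alpha\, f(\mathbf{W}_{\text{in}} \mathbf{u}_t + \mathbf{W} \mathbf{h}_t + \mathbf{b})
```

A single-cell sketch of the calling convention, assuming `AbstractReservoirRecurrentCell` and the `has_bias`/`sample_replicate`/`init_hidden_state` helpers live in `layers/lux_layers.jl` (not shown in this diff) and make the cell a valid Lux layer; `out_dims` is kept a multiple of `in_dims` since the default `weighted_init` builds block-structured input weights:

```julia
using ReservoirComputing, Random

esn = ESNCell(3 => 90; leak_coefficient=0.9)
ps, st = setup(Random.default_rng(), esn)

u = rand(Float32, 3)
# The first call seeds the hidden state through the stored rng and returns
# an (output, carry) pair; the carry form then advances the state one step.
(h, (carry,)), st = esn(u, ps, st)
(h2, _), st = esn((u, (carry,)), ps, st)
size(h)   # (90,)
```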
