Skip to content

Commit ed10b6d

Browse files
feat: more robust ESN formulation
1 parent 26ef75b commit ed10b6d

File tree

6 files changed

+112
-9
lines changed

6 files changed

+112
-9
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "0.11.4"
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
99
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
10+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
1213
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
@@ -34,6 +35,7 @@ CellularAutomata = "0.0.6"
3435
Compat = "4.16.0"
3536
ConcreteStructs = "0.2.3"
3637
DifferentialEquations = "7.16.1"
38+
Functors = "0.5.2"
3739
JET = "0.9.20"
3840
LIBSVM = "0.8"
3941
LinearAlgebra = "1.10"

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ We can either use the provided `ESN` or build one from scratch.
103103
We showcase the second option:
104104

105105
```julia
106+
using ReservoirComputing
106107
input_size = 3
107108
res_size = 300
108109
esn = ReservoirChain(

src/ReservoirComputing.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module ReservoirComputing
33
using ArrayInterface: ArrayInterface
44
using Compat: @compat
55
using ConcreteStructs: @concrete
6+
#using Functors
67
using LinearAlgebra: eigvals, mul!, I, qr, Diagonal, diag
78
using LuxCore: AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer,
89
setup, apply, replicate

src/models/esn.jl

Lines changed: 103 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,111 @@
1+
"""
    ESN

Echo state network container layer. Holds a recurrent reservoir `cell`, a
tuple of `states_modifiers` applied to the reservoir state, and a trainable
`readout`. The type-parameter tuple matches the field names so that
parameters and states are namespaced as `(:cell, :states_modifiers, :readout)`
by the `AbstractLuxContainerLayer` machinery.
"""
@concrete struct ESN <: AbstractLuxContainerLayer{(:cell, :states_modifiers, :readout)}
    cell               # reservoir; the ESN constructor stores a StatefulLayer-wrapped ESNCell here
    states_modifiers   # tuple of layers applied in order to the reservoir state
    readout            # readout layer whose weights `train!` fits by regression
end
6+
7+
# Promote a bare function to a `WrappedFunction` layer so it can live inside a
# layer tuple; anything that is already a layer passes through untouched.
# Uses multiple dispatch instead of the original `isa`-branch — behavior is
# identical, but dispatch is the idiomatic (and extensible) form.
_wrap_layer(x::Function) = WrappedFunction(x)
_wrap_layer(x) = x

# Apply `_wrap_layer` element-wise over a tuple of candidate layers.
_wrap_layers(xs::Tuple) = map(_wrap_layer, xs)
9+
110
"""
    ESN(in_dims, res_dims, out_dims, activation=tanh;
        readout_activation=identity, state_modifiers=(), kwargs...)

Construct an [`ESN`](@ref) from sizes: a stateful `ESNCell` reservoir of
`res_dims` units, an optional collection of state modifiers (bare functions
are wrapped as layers), and a `LinearReadout` mapping `res_dims => out_dims`.
Extra keyword arguments are forwarded to `ESNCell`.
"""
function ESN(in_dims::IntegerType, res_dims::IntegerType, out_dims::IntegerType, activation=tanh;
        readout_activation=identity,
        state_modifiers=(),
        kwargs...)
    reservoir = StatefulLayer(ESNCell(in_dims => res_dims, activation; kwargs...))
    # Accept a single modifier, a tuple, or a vector; normalize to a tuple.
    modifiers = if state_modifiers isa Tuple || state_modifiers isa AbstractVector
        _wrap_layers(Tuple(state_modifiers))
    else
        _wrap_layers((state_modifiers,))
    end
    readout = LinearReadout(res_dims => out_dims, readout_activation)
    return ESN(reservoir, modifiers, readout)
end
21+
22+
# Initialize parameters for each sub-layer, namespaced to mirror the ESN
# container fields `(cell, states_modifiers, readout)`.
function initialparameters(rng::AbstractRNG, esn::ESN)
    return (cell=initialparameters(rng, esn.cell),
        states_modifiers=Tuple(initialparameters(rng, layer)
                               for layer in esn.states_modifiers),
        readout=initialparameters(rng, esn.readout))
end
28+
29+
# Initialize states for each sub-layer, namespaced to mirror the ESN
# container fields `(cell, states_modifiers, readout)`.
function initialstates(rng::AbstractRNG, esn::ESN)
    return (cell=initialstates(rng, esn.cell),
        states_modifiers=Tuple(initialstates(rng, layer)
                               for layer in esn.states_modifiers),
        readout=initialstates(rng, esn.readout))
end
35+
36+
# Thread `x` through `layers` left-to-right, collecting each layer's updated
# state into a tuple. Implemented by recursion on the layer tuple so the
# result types are fully inferable — the original buffered states in a
# `Vector{Any}` and splatted it, which is type-unstable and allocates.
# Empty layer tuple: `x` passes through untouched with an empty state tuple,
# matching the original loop's zero-iteration behavior.
@inline _apply_seq(::Tuple{}, x, ::Tuple, ::Tuple) = (x, ())
@inline function _apply_seq(layers::Tuple, x, ps::Tuple, st::Tuple)
    y, st_head = apply(first(layers), x, first(ps), first(st))
    out, st_tail = _apply_seq(Base.tail(layers), y, Base.tail(ps), Base.tail(st))
    return out, (st_head, st_tail...)
end
45+
46+
# Forward pass: reservoir cell → state modifiers (in order) → readout.
# Returns the readout output together with the updated, namespaced states.
function (m::ESN)(x, ps, st)
    h, cell_st = apply(m.cell, x, ps.cell, st.cell)
    h, mods_st = _apply_seq(m.states_modifiers, h, ps.states_modifiers, st.states_modifiers)
    out, ro_st = apply(m.readout, h, ps.readout, st.readout)
    return out, (cell=cell_st, states_modifiers=mods_st, readout=ro_st)
end
52+
53+
"""
    reset_carry(esn::ESN, st; mode=:zeros, value=nothing, rng=nothing)

Return a copy of the state `st` whose recurrent carry has been re-initialized.

`mode` selects the fill: `:zeros` (default), `:randn` (using `rng`, or the
default RNG when `rng === nothing`), or `:value` (constant fill with `value`).
Shape and element type are taken from the existing carry when present;
otherwise a `(out_dims, 1)` `Float32` carry is created from the wrapped cell.

Throws `ArgumentError` for an unknown `mode`, or when `mode=:value` is given
without a `value`.
"""
function reset_carry(esn::ESN, st; mode=:zeros, value=nothing, rng=nothing)
    # Find the current carry & infer shape/eltype from it when available.
    c = get(st.cell, :carry, nothing)
    if c === nothing
        # No carry yet: fall back to the wrapped cell's output size.
        outd = esn.cell.cell.out_dims
        T = Float32
        sz = (outd, 1)
    else
        h = c[1] # carry is usually a 1-tuple (h,)
        T = eltype(h)
        sz = size(h)
    end

    new_h = if mode === :zeros
        zeros(T, sz)
    elseif mode === :randn
        randn(rng === nothing ? Random.default_rng() : rng, T, sz...)
    elseif mode === :value
        # Input validation must not be an `@assert` — asserts can be compiled
        # out at higher optimization levels.
        value === nothing &&
            throw(ArgumentError("Provide `value=` when mode=:value"))
        fill(T(value), sz)
    else
        throw(ArgumentError("Unknown mode=$(mode). Use :zeros, :randn, or :value."))
    end

    new_cell = merge(st.cell, (; carry=(new_h,)))
    return (cell=new_cell, states_modifiers=st.states_modifiers, readout=st.readout)
end
83+
84+
# Return a copy of the readout parameters with the `weight` entry replaced by `W`.
function _set_readout_weight(ps_readout::NamedTuple, W)
    return merge(ps_readout, (weight=W,))
end
85+
86+
"""
    train!(m::ESN, train_data, target_data, ps, st, train_method=StandardRidge(0.0);
        washout=0, return_states=false)

Drive the ESN over `train_data` (columns are time steps), collect the
post-modifier reservoir states, optionally drop the first `washout` columns,
and fit the readout weight with `train_method` against `target_data`.
Returns `(ps′, st′)` with the fitted readout weight, or
`((ps′, st′), states)` when `return_states=true`.
"""
function train!(m::ESN, train_data::AbstractMatrix, target_data::AbstractMatrix,
        ps, st, train_method=StandardRidge(0.0);
        washout::Int=0, return_states::Bool=false)
    state = st
    feats = Vector{Any}(undef, size(train_data, 2))
    for (step, col) in enumerate(eachcol(train_data))
        h, cell_st = apply(m.cell, col, ps.cell, state.cell)
        h, mods_st = _apply_seq(m.states_modifiers, h, ps.states_modifiers,
            state.states_modifiers)
        feats[step] = copy(h)
        # Thread the updated cell/modifier states into the next step.
        state = (cell=cell_st, states_modifiers=mods_st, readout=state.readout)
    end
    states = eltype(train_data).(reduce(hcat, feats))

    states_wo, targets_wo = washout > 0 ?
        _apply_washout(states, target_data, washout) :
        (states, target_data)

    W = train(train_method, states_wo, targets_wo)
    ps2 = (cell=ps.cell,
        states_modifiers=ps.states_modifiers,
        readout=_set_readout_weight(ps.readout, W))

    return return_states ? ((ps2, state), states_wo) : (ps2, state)
end
11110

12111
_basefuncstr(x) = sprint(show, x)

src/states.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ point with the input that it receives.
5454
)
5555
),
5656
NLAT2(),
57-
Readout(300+3 => 3)
57+
LinearReadout(300+3 => 3)
5858
)
5959
```
6060

src/train.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@ end
4444
washout::Int=0, return_states::Bool=false)
4545
4646
Trains the Reservoir Computer by creating the reservoir states from `train_data`,
47-
and then fiting the last [`Readout`](@ref) layer by (ridge)
47+
and then fitting the last [`LinearReadout`](@ref) layer by (ridge)
4848
linear regression onto `target_data`. The learned weights are written into `ps`.
4949
The returned state is the final state after running through the full sequence.
5050
5151
## Arguments
5252
53-
- `rc`: A [`ReservoirChain`](@ref) whose last trainable layer is a `Readout`.
53+
- `rc`: A [`ReservoirChain`](@ref) whose last trainable layer is a `LinearReadout`.
5454
- `train_data`: input sequence (columns are time steps).
5555
- `target_data`: targets aligned with `train_data`.
5656
- `ps, st`: current parameters and state.
@@ -71,7 +71,7 @@ The returned state is the final state after running through the full sequence.
7171
## Notes
7272
7373
- Features are produced by `collectstates(rc, train_data, ps, st)`. If you rely on
74-
the implicit collection of a [`Readout`](@ref), make sure that readout was created with
74+
the implicit collection of a [`LinearReadout`](@ref), make sure that readout was created with
7575
`include_collect=true`, or insert an explicit [`Collect()`](@ref) earlier in the chain.
7676
"""
7777
function train!(rc::ReservoirChain, train_data, target_data, ps, st,
@@ -122,7 +122,7 @@ end
122122
Kq = _quote_keys(K)
123123
tailKq = _quote_keys(tailK)
124124

125-
head_val = :((getfield(layers, 1) isa Readout)
125+
head_val = :((getfield(layers, 1) isa LinearReadout)
126126
? _setweight_rt(getfield(ps, 1), W)
127127
: getfield(ps, 1))
128128

0 commit comments

Comments
 (0)