Commit 9dc01cb

refac: finalize HybridESN, improve ESN generics
1 parent 07e2020 commit 9dc01cb

9 files changed: +147 additions, -219 deletions

README.md

Lines changed: 3 additions & 13 deletions
@@ -106,20 +106,10 @@ We showcase the second option:
 using ReservoirComputing
 input_size = 3
 res_size = 300
-esn = ReservoirChain(
-    StatefulLayer(
-        ESNCell(
-            input_size => res_size;
-            init_reservoir=rand_sparse(; radius=1.2, sparsity=6/300)
-        )
-    ),
-    NLAT2(),
-    Readout(res_size => input_size) # autoregressive so out_dims == in_dims
+esn = ESN(input_size, res_size, input_size;
+    init_reservoir=rand_sparse(; radius=1.2, sparsity=6/300),
+    state_modifiers=NLAT2
 )
-# alternative:
-# esn = ESN(input_size, res_size, input_size;
-#     init_reservoir=rand_sparse(; radius=1.2, sparsity=6/300)
-# )
 ```
 
 ### 3. Train the Echo State Network
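For orientation, here is a minimal sketch of the simplified constructor in use. The `LuxCore.setup` call, the `Xoshiro` seed, and the random input are illustrative assumptions rather than part of this commit, and the readout is untrained, so the forward pass only demonstrates shapes and wiring.

```julia
using ReservoirComputing
using LuxCore: setup
using Random: Xoshiro

input_size, res_size = 3, 300
esn = ESN(input_size, res_size, input_size;
    init_reservoir=rand_sparse(; radius=1.2, sparsity=6/300),
    state_modifiers=NLAT2)

ps, st = setup(Xoshiro(42), esn)   # LuxCore-style parameter/state setup
x = rand(Float32, input_size)      # one input column
y, st = esn(x, ps, st)             # cell -> NLAT2 -> (untrained) readout
```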

ext/RCCellularAutomataExt.jl

Lines changed: 36 additions & 53 deletions
@@ -5,42 +5,6 @@ import ReservoirComputing: RECACell, RECA
 using CellularAutomata
 using Random: randperm
 
-function (reca::RECACell)((inp, (ca_prev,)), ps, st::NamedTuple)
-    rm = reca.enc
-    T = eltype(inp)
-    ca0 = T.(encoding(rm, inp, T.(ca_prev)))
-    ca = CellularAutomaton(reca.automaton, ca0, rm.generations + 1)
-    evo = ca.evolution
-    feat2T = evo[2:end, :]
-    feats = reshape(permutedims(feat2T), rm.states_size)
-    ca_last = evo[end, :]
-    return (T.(feats), (T.(ca_last),)), st
-end
-
-function (reca::RECACell)(inp::AbstractVector, ps, st::NamedTuple)
-    ca = st.ca
-    return reca((inp, (ca,)), ps, st)
-end
-
-function RECA(in_dims::IntegerType,
-        out_dims::IntegerType,
-        automaton;
-        input_encoding::AbstractInputEncoding=RandomMapping(),
-        generations::Integer=8,
-        state_modifiers=(),
-        readout_activation=identity)
-
-    rm = create_encoding(input_encoding, in_dims, generations)
-    cell = RECACell(automaton, rm)
-
-    mods = state_modifiers isa Tuple || state_modifiers isa AbstractVector ?
-           Tuple(state_modifiers) : (state_modifiers,)
-
-    ro = Readout(rm.states_size => out_dims, readout_activation)
-
-    return ReservoirChain((StatefulLayer(cell), mods..., ro)...)
-end
-
 function RandomMapping(; permutations=8, expansion_size=40)
     RandomMapping(permutations, expansion_size)
 end
@@ -56,23 +20,6 @@ function create_encoding(rm::RandomMapping, in_dims::IntegerType, generations::I
     return RandomMaps(rm.permutations, rm.expansion_size, generations, maps, states_size, ca_size)
 end
 
-
-function reca_create_states(rm::RandomMaps, automata, input_data)
-    train_time = size(input_data, 2)
-    states = zeros(rm.states_size, train_time)
-    init_ca = zeros(rm.expansion_size * rm.permutations)
-
-    for i in 1:train_time
-        init_ca = encoding(rm, input_data[:, i], init_ca)
-        ca = CellularAutomaton(automata, init_ca, rm.generations + 1)
-        ca_states = ca.evolution[2:end, :]
-        states[:, i] = reshape(transpose(ca_states), rm.states_size)
-        init_ca = ca.evolution[end, :]
-    end
-
-    return states
-end
-
 function encoding(rm::RandomMaps, input_vector, tot_encoded_vector)
     input_size = size(input_vector, 1)
     #single_encoded_size = Int(size(tot_encoded_vector, 1)/permutations)
@@ -119,4 +66,40 @@ function mapping(input_size, mapped_vector_size)
     return randperm(mapped_vector_size)[1:input_size]
 end
 
+function (reca::RECACell)((inp, (ca_prev,)), ps, st::NamedTuple)
+    rm = reca.enc
+    T = eltype(inp)
+    ca0 = T.(encoding(rm, inp, T.(ca_prev)))
+    ca = CellularAutomaton(reca.automaton, ca0, rm.generations + 1)
+    evo = ca.evolution
+    feat2T = evo[2:end, :]
+    feats = reshape(permutedims(feat2T), rm.states_size)
+    ca_last = evo[end, :]
+    return (T.(feats), (T.(ca_last),)), st
+end
+
+function (reca::RECACell)(inp::AbstractVector, ps, st::NamedTuple)
+    ca = st.ca
+    return reca((inp, (ca,)), ps, st)
+end
+
+function RECA(in_dims::IntegerType,
+        out_dims::IntegerType,
+        automaton;
+        input_encoding::AbstractInputEncoding=RandomMapping(),
+        generations::Integer=8,
+        state_modifiers=(),
+        readout_activation=identity)
+
+    rm = create_encoding(input_encoding, in_dims, generations)
+    cell = RECACell(automaton, rm)
+
+    mods = state_modifiers isa Tuple || state_modifiers isa AbstractVector ?
+           Tuple(state_modifiers) : (state_modifiers,)
+
+    ro = Readout(rm.states_size => out_dims, readout_activation)
+
+    return ReservoirChain((StatefulLayer(cell), mods..., ro)...)
+end
+
 end #module
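For reference, a construction sketch of the relocated `RECA` constructor. `DCA(90)` (elementary rule 90) is assumed to come from CellularAutomata.jl, which is what activates this extension; the dimensions are placeholders.

```julia
using ReservoirComputing, CellularAutomata

# Yields a ReservoirChain: StatefulLayer(RECACell) -> state modifiers -> Readout,
# mirroring the constructor above. In/out dimensions here are placeholders.
reca = RECA(4, 4, DCA(90);
    input_encoding=RandomMapping(; permutations=8, expansion_size=40),
    generations=8)
```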

src/ReservoirComputing.jl

Lines changed: 5 additions & 5 deletions
@@ -3,7 +3,6 @@ module ReservoirComputing
 using ArrayInterface: ArrayInterface
 using Compat: @compat
 using ConcreteStructs: @concrete
-#using Functors
 using LinearAlgebra: eigvals, mul!, I, qr, Diagonal, diag
 using LuxCore: AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer,
                setup, apply, replicate
@@ -20,6 +19,7 @@ using WeightInitializers: DeviceAgnostic, PartialFunction, Utils
 const BoolType = Union{StaticBool,Bool,Val{true},Val{false}}
 const InputType = Tuple{<:AbstractArray,Tuple{<:AbstractArray}}
 const IntegerType = Union{Integer,StaticInteger}
+const RCFields = (:cells, :states_modifiers, :readout)
 
 abstract type AbstractReservoirComputer{Fields} <: AbstractLuxContainerLayer{Fields} end
 
@@ -38,10 +38,10 @@ include("train.jl")
 include("inits/inits_components.jl")
 include("inits/esn_inits.jl")
 #full models
-include("models/esn_utils.jl")
+include("models/esn_generics.jl")
 include("models/esn.jl")
-include("models/deepesn.jl")
-include("models/hybridesn.jl")
+include("models/esn_deep.jl")
+include("models/esn_hybrid.jl")
 #extensions
 include("extensions/reca.jl")
 
@@ -60,7 +60,7 @@ export block_diagonal, chaotic_init, cycle_jumps, delay_line, delay_line_backwar
 export add_jumps!, backward_connection!, delay_line!, reverse_simple_cycle!,
        scale_radius!, self_loop!, simple_cycle!
 export train
-export ESN, HybridESN, KnowledgeModel, DeepESN
+export ESN, HybridESN, DeepESN
 #reca
 export RECACell, RECA
 export RandomMapping, RandomMaps

src/layers/esn_cell.jl

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ end
 
 function ESNCell((in_dims, out_dims)::Pair{<:IntegerType,<:IntegerType}, activation=tanh;
         use_bias::BoolType=False(), init_bias=zeros32, init_reservoir=rand_sparse,
-        init_input=weighted_init, init_state=randn32, leak_coefficient=1.0)
+        init_input=scaled_rand, init_state=randn32, leak_coefficient=1.0)
     return ESNCell(activation, in_dims, out_dims, init_bias, init_reservoir,
         init_input, init_state, leak_coefficient, use_bias)
 end
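The only change here is the default input initializer, which moves from `weighted_init` to `scaled_rand`. A hedged note for downstream code tuned around the old default: assuming `weighted_init` remains exported, it can still be requested explicitly.

```julia
using ReservoirComputing

# Keep the pre-commit input initializer explicitly; dimensions are placeholders.
cell = ESNCell(3 => 300; init_input=weighted_init)
```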

src/models/esn.jl

Lines changed: 15 additions & 28 deletions
@@ -1,3 +1,11 @@
+"""
+    ESN(in_dims, res_dims, out_dims, activation=tanh;
+        leak_coefficient=1.0, init_reservoir=rand_sparse, init_input=scaled_rand,
+        init_bias=zeros32, init_state=randn32, use_bias=false,
+        state_modifiers=(), readout_activation=identity)
+
+Build a ESN.
+"""
 @concrete struct ESN <: AbstractEchoStateNetwork{(:cell, :states_modifiers, :readout)}
     cell
     states_modifiers
@@ -31,12 +39,17 @@ function initialstates(rng::AbstractRNG, esn::ESN)
     return (cell=st_cell, states_modifiers=st_mods, readout=st_ro)
 end
 
-function (esn::ESN)(inp, ps, st)
+function _partial_apply(esn::ESN, inp, ps, st)
     out, st_cell = apply(esn.cell, inp, ps.cell, st.cell)
     out, st_mods = _apply_seq(
         esn.states_modifiers, out, ps.states_modifiers, st.states_modifiers)
+    return out, (cell=st_cell, states_modifiers=st_mods)
+end
+
+function (esn::ESN)(inp, ps, st)
+    out, new_st = _partial_apply(esn, inp, ps, st)
     out, st_ro = apply(esn.readout, out, ps.readout, st.readout)
-    return out, (cell=st_cell, states_modifiers=st_mods, readout=st_ro)
+    return out, merge(new_st, (readout=st_ro,))
 end
 
 function resetcarry(rng::AbstractRNG, esn::ESN, st; init_carry=nothing)
@@ -59,29 +72,3 @@ function resetcarry(rng::AbstractRNG, esn::ESN, st; init_carry=nothing)
     new_cell = merge(st.cell, (; carry=new_state))
     return (cell=new_cell, states_modifiers=st.states_modifiers, readout=st.readout)
 end
-
-function collectstates(esn::ESN, data::AbstractMatrix, ps, st::NamedTuple)
-    newst = st
-    collected = Any[]
-    for inp in eachcol(data)
-        cell_y, st_cell = apply(esn.cell, inp, ps.cell, newst.cell)
-        state_t, st_mods = _apply_seq(
-            esn.states_modifiers, cell_y, ps.states_modifiers, newst.states_modifiers)
-        push!(collected, copy(state_t))
-        newst = (cell=st_cell, states_modifiers=st_mods, readout=newst.readout)
-    end
-    states = eltype(data).(reduce(hcat, collected))
-    @assert !isempty(collected)
-    states_raw = reduce(hcat, collected)
-    states = eltype(data).(states_raw)
-    return states, newst
-end
-
-function addreadout!(::ESN, output_matrix::AbstractMatrix,
-        ps::NamedTuple, st::NamedTuple)
-    @assert hasproperty(ps, :readout)
-    new_readout = _set_readout_weight(ps.readout, output_matrix)
-    return (cell=ps.cell,
-        states_modifiers=ps.states_modifiers,
-        readout=new_readout), st
-end

src/models/deepesn.jl renamed to src/models/esn_deep.jl

Lines changed: 8 additions & 13 deletions
@@ -5,7 +5,7 @@
         activation=tanh;
         leak_coefficient=1.0,
         init_reservoir=rand_sparse,
-        init_input=weighted_init,
+        init_input=scaled_rand,
         init_bias=zeros32,
         init_state=randn32,
         use_bias=false,
@@ -96,7 +96,7 @@ function initialstates(rng::AbstractRNG, desn::DeepESN)
     return (cells=st_cells, states_modifiers=st_mods, readout=st_ro)
 end
 
-function (desn::DeepESN)(inp, ps, st)
+function _partial_apply(desn::DeepESN, inp, ps, st)
     inp_t = inp
     n_layers = length(desn.cells)
     new_cell_st = Vector{Any}(undef, n_layers)
@@ -108,15 +108,19 @@ function (desn::DeepESN)(inp, ps, st)
             ps.states_modifiers[idx], st.states_modifiers[idx])
         new_mods_st[idx] = st_mods_i
     end
-    inp_t, st_ro = apply(desn.readout, inp_t, ps.readout, st.readout)
 
     return inp_t, (;
         cells=tuple(new_cell_st...),
        states_modifiers=tuple(new_mods_st...),
-        readout=st_ro,
     )
 end
 
+function (desn::DeepESN)(inp, ps, st)
+    out, new_st = _partial_apply(desn, inp, ps, st)
+    inp_t, st_ro = apply(desn.readout, out, ps.readout, st.readout)
+    return inp_t, merge(new_st, (readout=st_ro,))
+end
+
 function resetcarry(rng::AbstractRNG, desn::DeepESN, st; init_carry=nothing)
     n_layers = length(desn.cells)
 
@@ -189,12 +193,3 @@ end
 
 collectstates(m::DeepESN, data::AbstractVector, ps, st::NamedTuple) =
     collectstates(m, reshape(data, :, 1), ps, st)
-
-function addreadout!(::DeepESN, output_matrix::AbstractMatrix,
-        ps::NamedTuple, st::NamedTuple)
-    @assert hasproperty(ps, :readout)
-    new_readout = _set_readout_weight(ps.readout, output_matrix)
-    return (cells=ps.cells,
-        states_modifiers=ps.states_modifiers,
-        readout=new_readout), st
-end

src/models/esn_utils.jl renamed to src/models/esn_generics.jl

Lines changed: 22 additions & 0 deletions
@@ -43,3 +43,25 @@ _set_readout_weight(ps_readout::NamedTuple, wro) = merge(ps_readout, (; weight=w
 function resetcarry(rng::AbstractRNG, esn, ps, st; init_carry=nothing)
     return ps, resetcarry(rng, esn, st; init_carry=init_carry)
 end
+
+function collectstates(esn::AbstractEchoStateNetwork, data::AbstractMatrix, ps, st::NamedTuple)
+    newst = st
+    collected = Any[]
+    for inp in eachcol(data)
+        state_t, partial_st = _partial_apply(esn, inp, ps, newst)
+        push!(collected, copy(state_t))
+        newst = merge(partial_st, (readout=newst.readout,))
+    end
+    states = eltype(data).(reduce(hcat, collected))
+    @assert !isempty(collected)
+    states_raw = reduce(hcat, collected)
+    states = eltype(data).(states_raw)
+    return states, newst
+end
+
+function addreadout!(::AbstractEchoStateNetwork, output_matrix::AbstractMatrix,
+        ps::NamedTuple, st::NamedTuple)
+    @assert hasproperty(ps, :readout)
+    new_readout = _set_readout_weight(ps.readout, output_matrix)
+    return merge(ps, (readout=new_readout,)), st
+end
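These two generic methods replace the per-model `collectstates`/`addreadout!` definitions deleted from `esn.jl` and `esn_deep.jl`: `_partial_apply` runs everything up to, but not including, the readout, and `addreadout!` splices a fitted output matrix back into `ps`. A sketch of the training path they enable follows; `fit_readout_sketch` and its ridge solve are illustrative, not part of the package, whose exported `train` presumably wraps an equivalent procedure.

```julia
using ReservoirComputing
using LinearAlgebra: I

# Hypothetical helper showing how the generic pair composes; `reg` is a ridge penalty.
function fit_readout_sketch(esn, data::AbstractMatrix, targets::AbstractMatrix, ps, st; reg=1e-6)
    # Feature matrix with one column per time step; readout state is left untouched.
    states, st = ReservoirComputing.collectstates(esn, data, ps, st)
    # Ridge regression for the output weights: targets ≈ Wout * states.
    Wout = (targets * states') / (states * states' + reg * I)
    # Write Wout into ps.readout via the generic addreadout!.
    ps, st = ReservoirComputing.addreadout!(esn, Wout, ps, st)
    return ps, st
end
```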

src/models/esn_hybrid.jl

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+@concrete struct HybridESN <: AbstractEchoStateNetwork{(:cell, :states_modifiers, :readout, :knowledge_model)}
+    cell
+    knowledge_model
+    states_modifiers
+    readout
+end
+
+function HybridESN(km,
+        km_dims::IntegerType, in_dims::IntegerType,
+        res_dims::IntegerType, out_dims::IntegerType,
+        activation=tanh;
+        state_modifiers=(),
+        readout_activation=identity,
+        include_collect::BoolType=True(),
+        kwargs...)
+
+    esn_inp_size = in_dims + km_dims
+    cell = StatefulLayer(ESNCell(esn_inp_size => res_dims, activation; kwargs...))
+    mods_tuple = state_modifiers isa Tuple || state_modifiers isa AbstractVector ?
+                 Tuple(state_modifiers) : (state_modifiers,)
+    mods = _wrap_layers(mods_tuple)
+    ro = LinearReadout(res_dims + km_dims => out_dims, readout_activation;
+        include_collect=static(include_collect))
+    km_layer = km isa WrappedFunction ? km : WrappedFunction(km)
+    return HybridESN(cell, km_layer, mods, ro)
+end
+
+function initialparameters(rng::AbstractRNG, hesn::HybridESN)
+    ps_cell = initialparameters(rng, hesn.cell)
+    ps_km = initialparameters(rng, hesn.knowledge_model)
+    ps_mods = map(l -> initialparameters(rng, l), hesn.states_modifiers) |> Tuple
+    ps_ro = initialparameters(rng, hesn.readout)
+    return (cell=ps_cell, knowledge_model=ps_km, states_modifiers=ps_mods, readout=ps_ro)
+end
+
+function initialstates(rng::AbstractRNG, hesn::HybridESN)
+    st_cell = initialstates(rng, hesn.cell)
+    st_km = initialstates(rng, hesn.knowledge_model)
+    st_mods = map(l -> initialstates(rng, l), hesn.states_modifiers) |> Tuple
+    st_ro = initialstates(rng, hesn.readout)
+    return (cell=st_cell, knowledge_model=st_km, states_modifiers=st_mods, readout=st_ro)
+end
+
+function _partial_apply(hesn::HybridESN, inp, ps, st)
+    k_t, st_km = hesn.knowledge_model(inp, ps.knowledge_model, st.knowledge_model)
+    xin = vcat(k_t, inp)
+    r, st_cell = apply(hesn.cell, xin, ps.cell, st.cell)
+    rstar, st_mods = _apply_seq(hesn.states_modifiers, r, ps.states_modifiers, st.states_modifiers)
+    feats = vcat(k_t, rstar)
+    return feats, (cell=st_cell, states_modifiers=st_mods, knowledge_model=st_km)
+end
+
+function (hesn::HybridESN)(inp, ps, st)
+    feats, new_st = _partial_apply(hesn, inp, ps, st)
+    y, st_ro = apply(hesn.readout, feats, ps.readout, st.readout)
+    return y, merge(new_st, (readout=st_ro,))
+end
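A usage sketch for the finalized `HybridESN`. The knowledge model is any callable of the raw input (the constructor wraps it in a `WrappedFunction`); the toy `prior`, the dimensions, and the `LuxCore.setup` call are illustrative assumptions. As `_partial_apply` above shows, the reservoir sees `vcat(k_t, inp)` and the readout sees `vcat(k_t, modified_state)`.

```julia
using ReservoirComputing
using LuxCore: setup
using Random: Xoshiro

prior(u) = 0.9f0 .* u                     # toy knowledge model, in_dims -> km_dims
in_dims, km_dims, res_dims, out_dims = 3, 3, 300, 3

hesn = HybridESN(prior, km_dims, in_dims, res_dims, out_dims;
    state_modifiers=NLAT2)

ps, st = setup(Xoshiro(0), hesn)
x = rand(Float32, in_dims)
y, st = hesn(x, ps, st)   # readout acts on vcat(prior(x), modified reservoir state)
```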
