Skip to content

Commit 07e2020

Browse files
feat: improve DeepESN, add better typing to models
1 parent 602ecf0 commit 07e2020

File tree

6 files changed

+277
-111
lines changed

6 files changed

+277
-111
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ version = "0.11.4"
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
99
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
10-
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1110
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1211
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
1312
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
@@ -35,7 +34,6 @@ CellularAutomata = "0.0.6"
3534
Compat = "4.16.0"
3635
ConcreteStructs = "0.2.3"
3736
DifferentialEquations = "7.16.1"
38-
Functors = "0.5.2"
3937
JET = "0.9.20"
4038
LIBSVM = "0.8"
4139
LinearAlgebra = "1.10"

src/ReservoirComputing.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,22 @@ using ConcreteStructs: @concrete
66
#using Functors
77
using LinearAlgebra: eigvals, mul!, I, qr, Diagonal, diag
88
using LuxCore: AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer,
9-
setup, apply, replicate
9+
setup, apply, replicate
1010
import LuxCore: initialparameters, initialstates, statelength, outputsize
1111
using NNlib: fast_act, sigmoid
1212
using Random: Random, AbstractRNG, randperm
1313
using Static: StaticBool, StaticInt, StaticSymbol,
14-
True, False, static, known, dynamic, StaticInteger
14+
True, False, static, known, dynamic, StaticInteger
1515
using Reexport: Reexport, @reexport
1616
using WeightInitializers: DeviceAgnostic, PartialFunction, Utils
1717
@reexport using WeightInitializers
1818
@reexport using LuxCore: setup, apply, initialparameters, initialstates
1919

20-
const BoolType = Union{StaticBool, Bool, Val{true}, Val{false}}
21-
const InputType = Tuple{<:AbstractArray, Tuple{<:AbstractArray}}
22-
const IntegerType = Union{Integer, StaticInteger}
20+
const BoolType = Union{StaticBool,Bool,Val{true},Val{false}}
21+
const InputType = Tuple{<:AbstractArray,Tuple{<:AbstractArray}}
22+
const IntegerType = Union{Integer,StaticInteger}
23+
24+
abstract type AbstractReservoirComputer{Fields} <: AbstractLuxContainerLayer{Fields} end
2325

2426
#@compat(public, (initialparameters)) #do I need to add intialstates/parameters in compat?
2527

@@ -36,26 +38,27 @@ include("train.jl")
3638
include("inits/inits_components.jl")
3739
include("inits/esn_inits.jl")
3840
#full models
41+
include("models/esn_utils.jl")
3942
include("models/esn.jl")
4043
include("models/deepesn.jl")
4144
include("models/hybridesn.jl")
4245
#extensions
4346
include("extensions/reca.jl")
4447

4548
export ESNCell, StatefulLayer, LinearReadout, ReservoirChain, Collect, collectstates,
46-
train!,
47-
predict, reset_carry
49+
train!,
50+
predict, resetcarry
4851
export SVMReadout
4952
export Pad, Extend, NLAT1, NLAT2, NLAT3, PartialSquare, ExtendedSquare
5053
export StandardRidge
5154
export chebyshev_mapping, informed_init, logistic_mapping, minimal_init,
52-
modified_lm, scaled_rand, weighted_init, weighted_minimal
55+
modified_lm, scaled_rand, weighted_init, weighted_minimal
5356
export block_diagonal, chaotic_init, cycle_jumps, delay_line, delay_line_backward,
54-
double_cycle, forward_connection, low_connectivity, pseudo_svd, rand_sparse,
55-
selfloop_cycle, selfloop_delayline_backward, selfloop_feedback_cycle,
56-
selfloop_forward_connection, simple_cycle, true_double_cycle
57+
double_cycle, forward_connection, low_connectivity, pseudo_svd, rand_sparse,
58+
selfloop_cycle, selfloop_delayline_backward, selfloop_feedback_cycle,
59+
selfloop_forward_connection, simple_cycle, true_double_cycle
5760
export add_jumps!, backward_connection!, delay_line!, reverse_simple_cycle!,
58-
scale_radius!, self_loop!, simple_cycle!
61+
scale_radius!, self_loop!, simple_cycle!
5962
export train
6063
export ESN, HybridESN, KnowledgeModel, DeepESN
6164
#reca

src/models/deepesn.jl

Lines changed: 181 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,200 @@
1-
# --- helpers ---
2-
function _asvec(x, num_reservoirs::Int)
3-
if x === ()
4-
return ntuple(_ -> nothing, num_reservoirs)
5-
elseif x isa Tuple || x isa AbstractVector
6-
len = length(x)
7-
len == num_reservoirs && return Tuple(x)
8-
len == 1 && return ntuple(_ -> x[1], num_reservoirs)
9-
error("Expected length $num_reservoirs or 1 for per-layer argument, got $len")
10-
else
11-
return ntuple(_ -> x, num_reservoirs)
12-
end
1+
"""
2+
DeepESN(in_dims::Int,
3+
res_dims::AbstractVector{<:Int},
4+
out_dims,
5+
activation=tanh;
6+
leak_coefficient=1.0,
7+
init_reservoir=rand_sparse,
8+
init_input=weighted_init,
9+
init_bias=zeros32,
10+
init_state=randn32,
11+
use_bias=false,
12+
state_modifiers=(),
13+
readout_activation=identity)
14+
15+
Build a deep ESN: a stack of `StatefulLayer(ESNCell)` with optional per-layer
16+
state modifiers, followed by a final linear readout.
17+
"""
18+
@concrete struct DeepESN <: AbstractEchoStateNetwork{(:cells, :states_modifiers, :readout)}
19+
cells
20+
states_modifiers
21+
readout
1322
end
1423

15-
function DeepESN(in_dims::Int,
16-
res_dims::AbstractVector{<:Int},
17-
out_dims,
24+
function DeepESN(in_dims::IntegerType,
25+
res_dims::AbstractVector{<:IntegerType},
26+
out_dims::IntegerType,
1827
activation=tanh;
1928
leak_coefficient=1.0,
2029
init_reservoir=rand_sparse,
21-
init_input=weighted_init,
30+
init_input=scaled_rand,
2231
init_bias=zeros32,
2332
init_state=randn32,
2433
use_bias=false,
2534
state_modifiers=(),
2635
readout_activation=identity)
2736

28-
num_reservoirs = length(res_dims)
37+
n_layers = length(res_dims)
38+
acts = _asvec(activation, n_layers)
39+
leaks = _asvec(leak_coefficient, n_layers)
40+
ires = _asvec(init_reservoir, n_layers)
41+
iinp = _asvec(init_input, n_layers)
42+
ibias = _asvec(init_bias, n_layers)
43+
ist = _asvec(init_state, n_layers)
44+
ub = _asvec(use_bias, n_layers)
45+
mods0 = _asvec(state_modifiers, n_layers)
2946

30-
acts = _asvec(activation, num_reservoirs)
31-
leaksv = _asvec(leak_coefficient, num_reservoirs)
32-
inres = _asvec(init_reservoir, num_reservoirs)
33-
ininp = _asvec(init_input, num_reservoirs)
34-
inbias = _asvec(init_bias, num_reservoirs)
35-
inst = _asvec(init_state, num_reservoirs)
36-
ubias = _asvec(use_bias, num_reservoirs)
37-
mods = _asvec(state_modifiers, num_reservoirs)
47+
cells = Vector{Any}(undef, n_layers)
48+
states_modifiers = Vector{Any}(undef, n_layers)
3849

39-
layers = Any[]
4050
prev = in_dims
41-
for res in 1:num_reservoirs
42-
cell = ESNCell(prev => res_dims[res], acts[res];
43-
use_bias=static(ubias[res]),
44-
init_bias=inbias[res],
45-
init_reservoir=inres[res],
46-
init_input=ininp[res],
47-
init_state=inst[res],
48-
leak_coefficient=leaksv[res])
49-
50-
push!(layers, StatefulLayer(cell))
51-
if mods[res] !== nothing
52-
push!(layers, mods[res])
53-
end
54-
prev = res_dims[res]
51+
for idx in firstindex(res_dims):lastindex(res_dims)
52+
cell = ESNCell(prev => res_dims[idx], acts[idx];
53+
use_bias=static(ub[idx]),
54+
init_bias=ibias[idx],
55+
init_reservoir=ires[idx],
56+
init_input=iinp[idx],
57+
init_state=ist[idx],
58+
leak_coefficient=leaks[idx])
59+
cells[idx] = StatefulLayer(cell)
60+
states_modifiers[idx] = mods0[idx] === nothing ? nothing : _wrap_layer(mods0[idx])
61+
prev = res_dims[idx]
5562
end
63+
mods_per_layer = map(_coerce_layer_mods, states_modifiers) |> Tuple
5664
ro = LinearReadout(prev => out_dims, readout_activation)
57-
return ReservoirChain((layers..., ro)...)
65+
return DeepESN(Tuple(cells), mods_per_layer, ro)
66+
end
67+
68+
DeepESN(in_dims::Int, res_dim::Int, out_dims::Int; depth::Int=2, kwargs...) =
69+
DeepESN(in_dims, fill(res_dim, depth), out_dims; kwargs...)
70+
71+
function initialparameters(rng::AbstractRNG, desn::DeepESN)
72+
ps_cells = map(l -> initialparameters(rng, l), desn.cells) |> Tuple
73+
mods = desn.states_modifiers === nothing ? ntuple(_ -> (), length(desn.cells)) :
74+
desn.states_modifiers
75+
ps_mods = map(layer_mods ->
76+
(layer_mods === nothing ? () :
77+
map(l -> initialparameters(rng, l), layer_mods) |> Tuple),
78+
mods) |> Tuple
79+
80+
ps_ro = initialparameters(rng, desn.readout)
81+
return (cells=ps_cells, states_modifiers=ps_mods, readout=ps_ro)
82+
end
83+
84+
function initialstates(rng::AbstractRNG, desn::DeepESN)
85+
st_cells = map(l -> initialstates(rng, l), desn.cells) |> Tuple
86+
87+
mods = desn.states_modifiers === nothing ? ntuple(_ -> (), length(desn.cells)) :
88+
desn.states_modifiers
89+
90+
st_mods = map(layer_mods ->
91+
(layer_mods === nothing ? () :
92+
map(l -> initialstates(rng, l), layer_mods) |> Tuple),
93+
mods) |> Tuple
94+
95+
st_ro = initialstates(rng, desn.readout)
96+
return (cells=st_cells, states_modifiers=st_mods, readout=st_ro)
97+
end
98+
99+
function (desn::DeepESN)(inp, ps, st)
100+
inp_t = inp
101+
n_layers = length(desn.cells)
102+
new_cell_st = Vector{Any}(undef, n_layers)
103+
new_mods_st = Vector{Any}(undef, n_layers)
104+
for idx in firstindex(desn.cells):lastindex(desn.cells)
105+
inp_t, st_cell_i = apply(desn.cells[idx], inp_t, ps.cells[idx], st.cells[idx])
106+
new_cell_st[idx] = st_cell_i
107+
inp_t, st_mods_i = _apply_seq(desn.states_modifiers[idx], inp_t,
108+
ps.states_modifiers[idx], st.states_modifiers[idx])
109+
new_mods_st[idx] = st_mods_i
110+
end
111+
inp_t, st_ro = apply(desn.readout, inp_t, ps.readout, st.readout)
112+
113+
return inp_t, (;
114+
cells=tuple(new_cell_st...),
115+
states_modifiers=tuple(new_mods_st...),
116+
readout=st_ro,
117+
)
58118
end
59119

60-
function DeepESN(in_dims::Int, res_dims::Int, out_dims::Int; depth::Int=2, kwargs...)
61-
return DeepESN(in_dims, fill(res_dims, depth), out_dims; kwargs...)
120+
function resetcarry(rng::AbstractRNG, desn::DeepESN, st; init_carry=nothing)
121+
n_layers = length(desn.cells)
122+
123+
@inline function _layer_outdim(idx)
124+
st_i = st.cells[idx]
125+
if st_i.carry === nothing
126+
return desn.cells[idx].cell.out_dims
127+
else
128+
return size(first(st_i.carry), 1)
129+
end
130+
end
131+
132+
@inline function _init_for(idx)
133+
if init_carry === nothing
134+
return nothing
135+
elseif init_carry isa Function
136+
sz = _layer_outdim(idx)
137+
return (_asvec(init_carry(rng, sz)),)
138+
elseif init_carry isa Tuple || init_carry isa AbstractVector
139+
f = init_carry[idx]
140+
sz = _layer_outdim(idx)
141+
return f === nothing ? nothing : (_asvec(f(rng, sz)),)
142+
else
143+
throw(ArgumentError("init_carry must be nothing, a Function, or a Tuple/Vector of Functions"))
144+
end
145+
end
146+
147+
new_cells = ntuple(idx -> begin
148+
st_i = st.cells[idx]
149+
new_carry = _init_for(idx)
150+
merge(st_i, (; carry=new_carry))
151+
end, n_layers)
152+
153+
return (;
154+
cells=new_cells,
155+
states_modifiers=st.states_modifiers,
156+
readout=st.readout,
157+
)
158+
end
159+
160+
function collectstates(desn::DeepESN, data::AbstractMatrix, ps, st::NamedTuple)
161+
newst = st
162+
collected = Any[]
163+
n_layers = length(desn.cells)
164+
for inp in eachcol(data)
165+
inp_t = inp
166+
cell_st_parts = Vector{Any}(undef, n_layers)
167+
mods_st_parts = Vector{Any}(undef, n_layers)
168+
for idx in firstindex(desn.cells):lastindex(desn.cells)
169+
inp_t, st_cell_i = apply(desn.cells[idx], inp_t, ps.cells[idx], newst.cells[idx])
170+
cell_st_parts[idx] = st_cell_i
171+
inp_t, st_mods_i = _apply_seq(
172+
desn.states_modifiers[idx], inp_t,
173+
ps.states_modifiers[idx], newst.states_modifiers[idx]
174+
)
175+
mods_st_parts[idx] = st_mods_i
176+
end
177+
push!(collected, copy(inp_t))
178+
newst = (;
179+
cells=tuple(cell_st_parts...),
180+
states_modifiers=tuple(mods_st_parts...),
181+
readout=newst.readout,
182+
)
183+
end
184+
@assert !isempty(collected)
185+
states = eltype(data).(reduce(hcat, collected))
186+
187+
return states, newst
188+
end
189+
190+
collectstates(m::DeepESN, data::AbstractVector, ps, st::NamedTuple) =
191+
collectstates(m, reshape(data, :, 1), ps, st)
192+
193+
function addreadout!(::DeepESN, output_matrix::AbstractMatrix,
194+
ps::NamedTuple, st::NamedTuple)
195+
@assert hasproperty(ps, :readout)
196+
new_readout = _set_readout_weight(ps.readout, output_matrix)
197+
return (cells=ps.cells,
198+
states_modifiers=ps.states_modifiers,
199+
readout=new_readout), st
62200
end

0 commit comments

Comments
 (0)