11@concrete struct ESN <: AbstractLuxContainerLayer{(:cell, :states_modifiers, :readout)}
2- cell
3- states_modifiers
4- readout
2+ cell:: Any
3+ states_modifiers:: Any
4+ readout:: Any
55end
66
77_wrap_layer (x) = x isa Function ? WrappedFunction (x) : x
88_wrap_layers (xs:: Tuple ) = map (_wrap_layer, xs)
99
10- function ESN (in_dims:: IntegerType , res_dims:: IntegerType , out_dims:: IntegerType , activation= tanh;
11- readout_activation= identity,
12- state_modifiers= (),
13- kwargs... )
10+ function ESN (in_dims:: IntegerType , res_dims:: IntegerType ,
11+ out_dims:: IntegerType , activation = tanh;
12+ readout_activation = identity,
13+ state_modifiers = (),
14+ kwargs... )
1415 cell = StatefulLayer (ESNCell (in_dims => res_dims, activation; kwargs... ))
1516 mods_tuple = state_modifiers isa Tuple || state_modifiers isa AbstractVector ?
1617 Tuple (state_modifiers) : (state_modifiers,)
@@ -23,128 +24,84 @@ function initialparameters(rng::AbstractRNG, esn::ESN)
2324 ps_cell = initialparameters (rng, esn. cell)
2425 ps_mods = map (l -> initialparameters (rng, l), esn. states_modifiers) |> Tuple
2526 ps_ro = initialparameters (rng, esn. readout)
26- return (cell= ps_cell, states_modifiers= ps_mods, readout= ps_ro)
27+ return (cell = ps_cell, states_modifiers = ps_mods, readout = ps_ro)
2728end
2829
2930function initialstates (rng:: AbstractRNG , esn:: ESN )
3031 st_cell = initialstates (rng, esn. cell)
3132 st_mods = map (l -> initialstates (rng, l), esn. states_modifiers) |> Tuple
3233 st_ro = initialstates (rng, esn. readout)
33- return (cell= st_cell, states_modifiers= st_mods, readout= st_ro)
34+ return (cell = st_cell, states_modifiers = st_mods, readout = st_ro)
3435end
3536
36- @inline function _apply_seq (layers:: Tuple , x, ps:: Tuple , st:: Tuple )
37- n = length (layers)
38- new_st_parts = Vector {Any} (undef, n)
39- @inbounds for i in 1 : n
40- x, sti = apply (layers[i], x, ps[i], st[i])
41- new_st_parts[i] = sti
37+ @inline function _apply_seq (layers:: Tuple , inp, ps:: Tuple , st:: Tuple )
38+ new_st_parts = Vector {Any} (undef, length (layers))
39+ for idx in eachindex (layers)
40+ inp, sti = apply (layers[idx], inp, ps[idx], st[idx])
41+ new_st_parts[idx] = sti
4242 end
43- return x , tuple (new_st_parts... )
43+ return inp , tuple (new_st_parts... )
4444end
4545
46- function (m:: ESN )(x, ps, st)
47- y, st_cell = apply (m. cell, x, ps. cell, st. cell)
48- y, st_mods = _apply_seq (m. states_modifiers, y, ps. states_modifiers, st. states_modifiers)
49- y, st_ro = apply (m. readout, y, ps. readout, st. readout)
50- return y, (cell= st_cell, states_modifiers= st_mods, readout= st_ro)
46+ function (esn:: ESN )(inp:: AbstractVector , ps, st)
47+ out, st_cell = apply (esn. cell, inp, ps. cell, st. cell)
48+ out, st_mods = _apply_seq (
49+ esn. states_modifiers, out, ps. states_modifiers, st. states_modifiers)
50+ out, st_ro = apply (esn. readout, out, ps. readout, st. readout)
51+ return out, (cell = st_cell, states_modifiers = st_mods, readout = st_ro)
5152end
5253
53- function reset_carry (esn:: ESN , st; mode= :zeros , value= nothing , rng= nothing )
54- # Find current carry & infer shape/type
55- c = get (st. cell, :carry , nothing )
56- if c === nothing
54+ function reset_carry (rng:: AbstractRNG , esn:: ESN , ps, st; init_carry = nothing )
55+ carry = get (st. cell, :carry , nothing )
56+ if carry === nothing
5757 outd = esn. cell. cell. out_dims
58- T = Float32
59- sz = (outd, 1 )
58+ sz = outd
6059 else
61- h = c[1 ] # carry is usually a 1-tuple (h,)
62- T = eltype (h)
63- sz = size (h)
60+ state = first (carry)
61+ sz = size (state, 1 )
6462 end
6563
66- new_h = begin
67- if mode === :zeros
68- zeros (T, sz)
69- elseif mode === :randn
70- rng = rng === nothing ? Random. default_rng () : rng
71- randn (rng, T, sz... )
72- elseif mode === :value
73- @assert value != = nothing " Provide `value=` when mode=:value"
74- fill (T (value), sz)
75- else
76- error (" Unknown mode=$(mode) . Use :zeros, :randn, or :value." )
77- end
64+ if init_carry === nothing
65+ new_state = nothing
66+ else
67+ new_state = init_carry (rng, sz, 1 )
68+ new_state = (new_state,)
7869 end
7970
80- new_cell = merge (st. cell, (; carry= (new_h,)))
81- return (cell= new_cell, states_modifiers= st. states_modifiers, readout= st. readout)
71+ new_cell = merge (st. cell, (; carry = new_state))
72+ return ps,
73+ (cell = new_cell, states_modifiers = st. states_modifiers, readout = st. readout)
8274end
8375
84- _set_readout_weight (ps_readout:: NamedTuple , W) = merge (ps_readout, (; weight= W))
85-
86- function train! (m:: ESN , train_data:: AbstractMatrix , target_data:: AbstractMatrix ,
87- ps, st, train_method= StandardRidge (0.0 );
88- washout:: Int = 0 , return_states:: Bool = false )
76+ _set_readout_weight (ps_readout:: NamedTuple , wro) = merge (ps_readout, (; weight = wro))
8977
78+ function collectstates (esn:: ESN , data:: AbstractMatrix , ps, st:: NamedTuple )
9079 newst = st
91- collected = Vector {Any} (undef, size (train_data, 2 ))
92- @inbounds for (t, x) in enumerate (eachcol (train_data))
93- y, st_cell = apply (m. cell, x, ps. cell, newst. cell)
94- y, st_mods = _apply_seq (m. states_modifiers, y, ps. states_modifiers, newst. states_modifiers)
95- collected[t] = copy (y)
96- newst = (cell= st_cell, states_modifiers= st_mods, readout= newst. readout)
80+ collected = Any[]
81+ for inp in eachcol (data)
82+ cell_y, st_cell = apply (esn. cell, inp, ps. cell, newst. cell)
83+ state_t, st_mods = _apply_seq (
84+ esn. states_modifiers, cell_y, ps. states_modifiers, newst. states_modifiers)
85+ push! (collected, copy (state_t))
86+ newst = (cell = st_cell, states_modifiers = st_mods, readout = newst. readout)
9787 end
98- states = eltype (train_data).(reduce (hcat, collected))
99-
100- states_wo, targets_wo =
101- washout > 0 ? _apply_washout (states, target_data, washout) : (states, target_data)
102-
103- W = train (train_method, states_wo, targets_wo)
104- ps2 = (cell= ps. cell,
105- states_modifiers= ps. states_modifiers,
106- readout= _set_readout_weight (ps. readout, W))
107-
108- return return_states ? ((ps2, newst), states_wo) : (ps2, newst)
88+ states = eltype (data).(reduce (hcat, collected))
89+ @assert ! isempty (collected)
90+ states_raw = reduce (hcat, collected)
91+ states = eltype (data).(states_raw)
92+ return states, newst
10993end
11094
111- _basefuncstr (x) = sprint (show, x)
95+ function train! (esn:: ESN , train_data:: AbstractMatrix , target_data:: AbstractMatrix ,
96+ ps, st, train_method = StandardRidge (0.0 );
97+ washout:: Int = 0 , return_states:: Bool = false )
98+ states, newst = collectstates (esn, train_data, ps, st)
99+ states_wo, targets_wo = washout > 0 ? _apply_washout (states, target_data, washout) :
100+ (states, target_data)
101+ wro = train (train_method, states_wo, targets_wo)
102+ ps2 = (cell = ps. cell,
103+ states_modifiers = ps. states_modifiers,
104+ readout = _set_readout_weight (ps. readout, wro))
112105
113- _getflag (x, sym:: Symbol , default= false ) = begin
114- v = known (getproperty (x, Val (sym)))
115- v === nothing ? default : v
116- end
117-
118- function Base. show (io:: IO , :: MIME"text/plain" , rc:: ReservoirChain )
119- L = collect (pairs (rc. layers))
120- if ! isempty (L) && (L[1 ][2 ] isa StatefulLayer) && (L[end ][2 ] isa LinearReadout)
121- sl = L[1 ][2 ]
122- ro = L[end ][2 ]
123- if sl. cell isa ESNCell
124- esn = sl. cell
125- mods = (length (L) > 2 ) ? map (x -> _basefuncstr (x[2 ]), L[2 : end - 1 ]) : String[]
126- print (io, " ESN($(esn. in_dims) => $(esn. out_dims) ; " ,
127- " activation=" , esn. activation,
128- " , leak=" , esn. leak_coefficient,
129- " , readout=" , ro. activation)
130- ic = _getflag (ro, :include_collect , false )
131- ic && print (io, " , include_collect=true" )
132- if ! _getflag (esn, :use_bias , false )
133- print (io, " , use_bias=false" )
134- end
135- if ! isempty (mods)
136- print (io, " , modifiers=[" , join (mods, " , " ), " ]" )
137- end
138- print (io, " )" )
139- return
140- end
141- end
142- strs = map (x -> _basefuncstr (x[2 ]), L)
143- if length (strs) <= 2
144- print (io, " ReservoirChain(" , join (strs, " , " ), " )" )
145- else
146- print (io, " ReservoirChain(\n " , join (strs, " ,\n " ), " \n )" )
147- end
106+ return return_states ? ((ps2, newst), states_wo) : (ps2, newst)
148107end
149-
150- Base. show (io:: IO , rc:: ReservoirChain ) = show (io, MIME " text/plain" (), rc)
0 commit comments