|
| 1 | +@concrete struct ESN <: AbstractLuxContainerLayer{(:cell, :states_modifiers, :readout)} |
| 2 | + cell |
| 3 | + states_modifiers |
| 4 | + readout |
| 5 | +end |
| 6 | + |
| 7 | +_wrap_layer(x) = x isa Function ? WrappedFunction(x) : x |
| 8 | +_wrap_layers(xs::Tuple) = map(_wrap_layer, xs) |
| 9 | + |
1 | 10 | function ESN(in_dims::IntegerType, res_dims::IntegerType, out_dims::IntegerType, activation=tanh; |
2 | 11 | readout_activation=identity, |
3 | 12 | state_modifiers=(), |
4 | 13 | kwargs...) |
5 | | - cell = ESNCell(in_dims => res_dims, activation; kwargs...) |
6 | | - mods = state_modifiers isa Tuple || state_modifiers isa AbstractVector ? |
7 | | - Tuple(state_modifiers) : (state_modifiers,) |
| 14 | + cell = StatefulLayer(ESNCell(in_dims => res_dims, activation; kwargs...)) |
| 15 | + mods_tuple = state_modifiers isa Tuple || state_modifiers isa AbstractVector ? |
| 16 | + Tuple(state_modifiers) : (state_modifiers,) |
| 17 | + mods = _wrap_layers(mods_tuple) |
8 | 18 | ro = LinearReadout(res_dims => out_dims, readout_activation) |
9 | | - return ReservoirChain((StatefulLayer(cell), mods..., ro)...) |
| 19 | + return ESN(cell, mods, ro) |
| 20 | +end |
| 21 | + |
| 22 | +function initialparameters(rng::AbstractRNG, esn::ESN) |
| 23 | + ps_cell = initialparameters(rng, esn.cell) |
| 24 | + ps_mods = map(l -> initialparameters(rng, l), esn.states_modifiers) |> Tuple |
| 25 | + ps_ro = initialparameters(rng, esn.readout) |
| 26 | + return (cell=ps_cell, states_modifiers=ps_mods, readout=ps_ro) |
| 27 | +end |
| 28 | + |
| 29 | +function initialstates(rng::AbstractRNG, esn::ESN) |
| 30 | + st_cell = initialstates(rng, esn.cell) |
| 31 | + st_mods = map(l -> initialstates(rng, l), esn.states_modifiers) |> Tuple |
| 32 | + st_ro = initialstates(rng, esn.readout) |
| 33 | + return (cell=st_cell, states_modifiers=st_mods, readout=st_ro) |
| 34 | +end |
| 35 | + |
| 36 | +@inline function _apply_seq(layers::Tuple, x, ps::Tuple, st::Tuple) |
| 37 | + n = length(layers) |
| 38 | + new_st_parts = Vector{Any}(undef, n) |
| 39 | + @inbounds for i in 1:n |
| 40 | + x, sti = apply(layers[i], x, ps[i], st[i]) |
| 41 | + new_st_parts[i] = sti |
| 42 | + end |
| 43 | + return x, tuple(new_st_parts...) |
| 44 | +end |
| 45 | + |
| 46 | +function (m::ESN)(x, ps, st) |
| 47 | + y, st_cell = apply(m.cell, x, ps.cell, st.cell) |
| 48 | + y, st_mods = _apply_seq(m.states_modifiers, y, ps.states_modifiers, st.states_modifiers) |
| 49 | + y, st_ro = apply(m.readout, y, ps.readout, st.readout) |
| 50 | + return y, (cell=st_cell, states_modifiers=st_mods, readout=st_ro) |
| 51 | +end |
| 52 | + |
| 53 | +function reset_carry(esn::ESN, st; mode=:zeros, value=nothing, rng=nothing) |
| 54 | + # Find current carry & infer shape/type |
| 55 | + c = get(st.cell, :carry, nothing) |
| 56 | + if c === nothing |
| 57 | + outd = esn.cell.cell.out_dims |
| 58 | + T = Float32 |
| 59 | + sz = (outd, 1) |
| 60 | + else |
| 61 | + h = c[1] # carry is usually a 1-tuple (h,) |
| 62 | + T = eltype(h) |
| 63 | + sz = size(h) |
| 64 | + end |
| 65 | + |
| 66 | + new_h = begin |
| 67 | + if mode === :zeros |
| 68 | + zeros(T, sz) |
| 69 | + elseif mode === :randn |
| 70 | + rng = rng === nothing ? Random.default_rng() : rng |
| 71 | + randn(rng, T, sz...) |
| 72 | + elseif mode === :value |
| 73 | + @assert value !== nothing "Provide `value=` when mode=:value" |
| 74 | + fill(T(value), sz) |
| 75 | + else |
| 76 | + error("Unknown mode=$(mode). Use :zeros, :randn, or :value.") |
| 77 | + end |
| 78 | + end |
| 79 | + |
| 80 | + new_cell = merge(st.cell, (; carry=(new_h,))) |
| 81 | + return (cell=new_cell, states_modifiers=st.states_modifiers, readout=st.readout) |
| 82 | +end |
| 83 | + |
| 84 | +_set_readout_weight(ps_readout::NamedTuple, W) = merge(ps_readout, (; weight=W)) |
| 85 | + |
| 86 | +function train!(m::ESN, train_data::AbstractMatrix, target_data::AbstractMatrix, |
| 87 | + ps, st, train_method=StandardRidge(0.0); |
| 88 | + washout::Int=0, return_states::Bool=false) |
| 89 | + |
| 90 | + newst = st |
| 91 | + collected = Vector{Any}(undef, size(train_data, 2)) |
| 92 | + @inbounds for (t, x) in enumerate(eachcol(train_data)) |
| 93 | + y, st_cell = apply(m.cell, x, ps.cell, newst.cell) |
| 94 | + y, st_mods = _apply_seq(m.states_modifiers, y, ps.states_modifiers, newst.states_modifiers) |
| 95 | + collected[t] = copy(y) |
| 96 | + newst = (cell=st_cell, states_modifiers=st_mods, readout=newst.readout) |
| 97 | + end |
| 98 | + states = eltype(train_data).(reduce(hcat, collected)) |
| 99 | + |
| 100 | + states_wo, targets_wo = |
| 101 | + washout > 0 ? _apply_washout(states, target_data, washout) : (states, target_data) |
| 102 | + |
| 103 | + W = train(train_method, states_wo, targets_wo) |
| 104 | + ps2 = (cell=ps.cell, |
| 105 | + states_modifiers=ps.states_modifiers, |
| 106 | + readout=_set_readout_weight(ps.readout, W)) |
| 107 | + |
| 108 | + return return_states ? ((ps2, newst), states_wo) : (ps2, newst) |
10 | 109 | end |
11 | 110 |
|
12 | 111 | _basefuncstr(x) = sprint(show, x) |
|
0 commit comments