|
1 | | -# --- helpers --- |
2 | | -function _asvec(x, num_reservoirs::Int) |
3 | | - if x === () |
4 | | - return ntuple(_ -> nothing, num_reservoirs) |
5 | | - elseif x isa Tuple || x isa AbstractVector |
6 | | - len = length(x) |
7 | | - len == num_reservoirs && return Tuple(x) |
8 | | - len == 1 && return ntuple(_ -> x[1], num_reservoirs) |
9 | | - error("Expected length $num_reservoirs or 1 for per-layer argument, got $len") |
10 | | - else |
11 | | - return ntuple(_ -> x, num_reservoirs) |
12 | | - end |
| 1 | +""" |
| 2 | + DeepESN(in_dims::Int, |
| 3 | + res_dims::AbstractVector{<:Int}, |
| 4 | + out_dims, |
| 5 | + activation=tanh; |
| 6 | + leak_coefficient=1.0, |
| 7 | + init_reservoir=rand_sparse, |
| 8 | + init_input=weighted_init, |
| 9 | + init_bias=zeros32, |
| 10 | + init_state=randn32, |
| 11 | + use_bias=false, |
| 12 | + state_modifiers=(), |
| 13 | + readout_activation=identity) |
| 14 | +
|
| 15 | +Build a deep ESN: a stack of `StatefulLayer(ESNCell)` with optional per-layer |
| 16 | +state modifiers, followed by a final linear readout. |
| 17 | +""" |
| 18 | +@concrete struct DeepESN <: AbstractEchoStateNetwork{(:cells, :states_modifiers, :readout)} |
| 19 | + cells |
| 20 | + states_modifiers |
| 21 | + readout |
13 | 22 | end |
14 | 23 |
|
15 | | -function DeepESN(in_dims::Int, |
16 | | - res_dims::AbstractVector{<:Int}, |
17 | | - out_dims, |
| 24 | +function DeepESN(in_dims::IntegerType, |
| 25 | + res_dims::AbstractVector{<:IntegerType}, |
| 26 | + out_dims::IntegerType, |
18 | 27 | activation=tanh; |
19 | 28 | leak_coefficient=1.0, |
20 | 29 | init_reservoir=rand_sparse, |
21 | | - init_input=weighted_init, |
| 30 | + init_input=scaled_rand, |
22 | 31 | init_bias=zeros32, |
23 | 32 | init_state=randn32, |
24 | 33 | use_bias=false, |
25 | 34 | state_modifiers=(), |
26 | 35 | readout_activation=identity) |
27 | 36 |
|
28 | | - num_reservoirs = length(res_dims) |
| 37 | + n_layers = length(res_dims) |
| 38 | + acts = _asvec(activation, n_layers) |
| 39 | + leaks = _asvec(leak_coefficient, n_layers) |
| 40 | + ires = _asvec(init_reservoir, n_layers) |
| 41 | + iinp = _asvec(init_input, n_layers) |
| 42 | + ibias = _asvec(init_bias, n_layers) |
| 43 | + ist = _asvec(init_state, n_layers) |
| 44 | + ub = _asvec(use_bias, n_layers) |
| 45 | + mods0 = _asvec(state_modifiers, n_layers) |
29 | 46 |
|
30 | | - acts = _asvec(activation, num_reservoirs) |
31 | | - leaksv = _asvec(leak_coefficient, num_reservoirs) |
32 | | - inres = _asvec(init_reservoir, num_reservoirs) |
33 | | - ininp = _asvec(init_input, num_reservoirs) |
34 | | - inbias = _asvec(init_bias, num_reservoirs) |
35 | | - inst = _asvec(init_state, num_reservoirs) |
36 | | - ubias = _asvec(use_bias, num_reservoirs) |
37 | | - mods = _asvec(state_modifiers, num_reservoirs) |
| 47 | + cells = Vector{Any}(undef, n_layers) |
| 48 | + states_modifiers = Vector{Any}(undef, n_layers) |
38 | 49 |
|
39 | | - layers = Any[] |
40 | 50 | prev = in_dims |
41 | | - for res in 1:num_reservoirs |
42 | | - cell = ESNCell(prev => res_dims[res], acts[res]; |
43 | | - use_bias=static(ubias[res]), |
44 | | - init_bias=inbias[res], |
45 | | - init_reservoir=inres[res], |
46 | | - init_input=ininp[res], |
47 | | - init_state=inst[res], |
48 | | - leak_coefficient=leaksv[res]) |
49 | | - |
50 | | - push!(layers, StatefulLayer(cell)) |
51 | | - if mods[res] !== nothing |
52 | | - push!(layers, mods[res]) |
53 | | - end |
54 | | - prev = res_dims[res] |
| 51 | + for idx in firstindex(res_dims):lastindex(res_dims) |
| 52 | + cell = ESNCell(prev => res_dims[idx], acts[idx]; |
| 53 | + use_bias=static(ub[idx]), |
| 54 | + init_bias=ibias[idx], |
| 55 | + init_reservoir=ires[idx], |
| 56 | + init_input=iinp[idx], |
| 57 | + init_state=ist[idx], |
| 58 | + leak_coefficient=leaks[idx]) |
| 59 | + cells[idx] = StatefulLayer(cell) |
| 60 | + states_modifiers[idx] = mods0[idx] === nothing ? nothing : _wrap_layer(mods0[idx]) |
| 61 | + prev = res_dims[idx] |
55 | 62 | end |
| 63 | + mods_per_layer = map(_coerce_layer_mods, states_modifiers) |> Tuple |
56 | 64 | ro = LinearReadout(prev => out_dims, readout_activation) |
57 | | - return ReservoirChain((layers..., ro)...) |
| 65 | + return DeepESN(Tuple(cells), mods_per_layer, ro) |
| 66 | +end |
| 67 | + |
| 68 | +DeepESN(in_dims::Int, res_dim::Int, out_dims::Int; depth::Int=2, kwargs...) = |
| 69 | + DeepESN(in_dims, fill(res_dim, depth), out_dims; kwargs...) |
| 70 | + |
| 71 | +function initialparameters(rng::AbstractRNG, desn::DeepESN) |
| 72 | + ps_cells = map(l -> initialparameters(rng, l), desn.cells) |> Tuple |
| 73 | + mods = desn.states_modifiers === nothing ? ntuple(_ -> (), length(desn.cells)) : |
| 74 | + desn.states_modifiers |
| 75 | + ps_mods = map(layer_mods -> |
| 76 | + (layer_mods === nothing ? () : |
| 77 | + map(l -> initialparameters(rng, l), layer_mods) |> Tuple), |
| 78 | + mods) |> Tuple |
| 79 | + |
| 80 | + ps_ro = initialparameters(rng, desn.readout) |
| 81 | + return (cells=ps_cells, states_modifiers=ps_mods, readout=ps_ro) |
| 82 | +end |
| 83 | + |
| 84 | +function initialstates(rng::AbstractRNG, desn::DeepESN) |
| 85 | + st_cells = map(l -> initialstates(rng, l), desn.cells) |> Tuple |
| 86 | + |
| 87 | + mods = desn.states_modifiers === nothing ? ntuple(_ -> (), length(desn.cells)) : |
| 88 | + desn.states_modifiers |
| 89 | + |
| 90 | + st_mods = map(layer_mods -> |
| 91 | + (layer_mods === nothing ? () : |
| 92 | + map(l -> initialstates(rng, l), layer_mods) |> Tuple), |
| 93 | + mods) |> Tuple |
| 94 | + |
| 95 | + st_ro = initialstates(rng, desn.readout) |
| 96 | + return (cells=st_cells, states_modifiers=st_mods, readout=st_ro) |
| 97 | +end |
| 98 | + |
| 99 | +function (desn::DeepESN)(inp, ps, st) |
| 100 | + inp_t = inp |
| 101 | + n_layers = length(desn.cells) |
| 102 | + new_cell_st = Vector{Any}(undef, n_layers) |
| 103 | + new_mods_st = Vector{Any}(undef, n_layers) |
| 104 | + for idx in firstindex(desn.cells):lastindex(desn.cells) |
| 105 | + inp_t, st_cell_i = apply(desn.cells[idx], inp_t, ps.cells[idx], st.cells[idx]) |
| 106 | + new_cell_st[idx] = st_cell_i |
| 107 | + inp_t, st_mods_i = _apply_seq(desn.states_modifiers[idx], inp_t, |
| 108 | + ps.states_modifiers[idx], st.states_modifiers[idx]) |
| 109 | + new_mods_st[idx] = st_mods_i |
| 110 | + end |
| 111 | + inp_t, st_ro = apply(desn.readout, inp_t, ps.readout, st.readout) |
| 112 | + |
| 113 | + return inp_t, (; |
| 114 | + cells=tuple(new_cell_st...), |
| 115 | + states_modifiers=tuple(new_mods_st...), |
| 116 | + readout=st_ro, |
| 117 | + ) |
58 | 118 | end |
59 | 119 |
|
60 | | -function DeepESN(in_dims::Int, res_dims::Int, out_dims::Int; depth::Int=2, kwargs...) |
61 | | - return DeepESN(in_dims, fill(res_dims, depth), out_dims; kwargs...) |
| 120 | +function resetcarry(rng::AbstractRNG, desn::DeepESN, st; init_carry=nothing) |
| 121 | + n_layers = length(desn.cells) |
| 122 | + |
| 123 | + @inline function _layer_outdim(idx) |
| 124 | + st_i = st.cells[idx] |
| 125 | + if st_i.carry === nothing |
| 126 | + return desn.cells[idx].cell.out_dims |
| 127 | + else |
| 128 | + return size(first(st_i.carry), 1) |
| 129 | + end |
| 130 | + end |
| 131 | + |
| 132 | + @inline function _init_for(idx) |
| 133 | + if init_carry === nothing |
| 134 | + return nothing |
| 135 | + elseif init_carry isa Function |
| 136 | + sz = _layer_outdim(idx) |
| 137 | + return (_asvec(init_carry(rng, sz)),) |
| 138 | + elseif init_carry isa Tuple || init_carry isa AbstractVector |
| 139 | + f = init_carry[idx] |
| 140 | + sz = _layer_outdim(idx) |
| 141 | + return f === nothing ? nothing : (_asvec(f(rng, sz)),) |
| 142 | + else |
| 143 | + throw(ArgumentError("init_carry must be nothing, a Function, or a Tuple/Vector of Functions")) |
| 144 | + end |
| 145 | + end |
| 146 | + |
| 147 | + new_cells = ntuple(idx -> begin |
| 148 | + st_i = st.cells[idx] |
| 149 | + new_carry = _init_for(idx) |
| 150 | + merge(st_i, (; carry=new_carry)) |
| 151 | + end, n_layers) |
| 152 | + |
| 153 | + return (; |
| 154 | + cells=new_cells, |
| 155 | + states_modifiers=st.states_modifiers, |
| 156 | + readout=st.readout, |
| 157 | + ) |
| 158 | +end |
| 159 | + |
| 160 | +function collectstates(desn::DeepESN, data::AbstractMatrix, ps, st::NamedTuple) |
| 161 | + newst = st |
| 162 | + collected = Any[] |
| 163 | + n_layers = length(desn.cells) |
| 164 | + for inp in eachcol(data) |
| 165 | + inp_t = inp |
| 166 | + cell_st_parts = Vector{Any}(undef, n_layers) |
| 167 | + mods_st_parts = Vector{Any}(undef, n_layers) |
| 168 | + for idx in firstindex(desn.cells):lastindex(desn.cells) |
| 169 | + inp_t, st_cell_i = apply(desn.cells[idx], inp_t, ps.cells[idx], newst.cells[idx]) |
| 170 | + cell_st_parts[idx] = st_cell_i |
| 171 | + inp_t, st_mods_i = _apply_seq( |
| 172 | + desn.states_modifiers[idx], inp_t, |
| 173 | + ps.states_modifiers[idx], newst.states_modifiers[idx] |
| 174 | + ) |
| 175 | + mods_st_parts[idx] = st_mods_i |
| 176 | + end |
| 177 | + push!(collected, copy(inp_t)) |
| 178 | + newst = (; |
| 179 | + cells=tuple(cell_st_parts...), |
| 180 | + states_modifiers=tuple(mods_st_parts...), |
| 181 | + readout=newst.readout, |
| 182 | + ) |
| 183 | + end |
| 184 | + @assert !isempty(collected) |
| 185 | + states = eltype(data).(reduce(hcat, collected)) |
| 186 | + |
| 187 | + return states, newst |
| 188 | +end |
| 189 | + |
| 190 | +collectstates(m::DeepESN, data::AbstractVector, ps, st::NamedTuple) = |
| 191 | + collectstates(m, reshape(data, :, 1), ps, st) |
| 192 | + |
| 193 | +function addreadout!(::DeepESN, output_matrix::AbstractMatrix, |
| 194 | + ps::NamedTuple, st::NamedTuple) |
| 195 | + @assert hasproperty(ps, :readout) |
| 196 | + new_readout = _set_readout_weight(ps.readout, output_matrix) |
| 197 | + return (cells=ps.cells, |
| 198 | + states_modifiers=ps.states_modifiers, |
| 199 | + readout=new_readout), st |
62 | 200 | end |
0 commit comments