|
| 1 | +@concrete struct ReservoirComputer <: AbstractReservoirComputer{(:reservoir, :states_modifiers, :readout)} |
| 2 | + reservoir::Any |
| 3 | + states_modifiers::Any |
| 4 | + readout::Any |
| 5 | +end |
| 6 | + |
| 7 | +function initialparameters(rng::AbstractRNG, rc::ReservoirComputer) |
| 8 | + ps_res = initialparameters(rng, rc.reservoir) |
| 9 | + ps_mods = map(l -> initialparameters(rng, l), rc.states_modifiers) |> Tuple |
| 10 | + ps_ro = initialparameters(rng, rc.readout) |
| 11 | + return (reservoir = ps_res, states_modifiers = ps_mods, readout = ps_ro) |
| 12 | +end |
| 13 | + |
| 14 | +function initialstates(rng::AbstractRNG, rc::ReservoirComputer) |
| 15 | + st_res = initialstates(rng, rc.reservoir) |
| 16 | + st_mods = map(l -> initialstates(rng, l), rc.states_modifiers) |> Tuple |
| 17 | + st_ro = initialstates(rng, rc.readout) |
| 18 | + return (reservoir = st_res, states_modifiers = st_mods, readout = st_ro) |
| 19 | +end |
| 20 | + |
| 21 | +@inline function _apply_seq(layers::Tuple, inp, ps::Tuple, st::Tuple) |
| 22 | + new_st_parts = Vector{Any}(undef, length(layers)) |
| 23 | + for idx in eachindex(layers) |
| 24 | + inp, sti = apply(layers[idx], inp, ps[idx], st[idx]) |
| 25 | + new_st_parts[idx] = sti |
| 26 | + end |
| 27 | + return inp, tuple(new_st_parts...) |
| 28 | +end |
| 29 | + |
| 30 | +function _partial_apply(rc::ReservoirComputer, inp, ps, st) |
| 31 | + out, st_res = apply(rc.reservoir, inp, ps.reservoir, st.reservoir) |
| 32 | + out, |
| 33 | + st_mods = _apply_seq( |
| 34 | + rc.states_modifiers, out, ps.states_modifiers, st.states_modifiers) |
| 35 | + return out, (reservoir = st_res, states_modifiers = st_mods) |
| 36 | +end |
| 37 | + |
| 38 | +function (rc::AbstractReservoirComputer)(inp, ps, st) |
| 39 | + out, new_st = _partial_apply(rc, inp, ps, st) |
| 40 | + out, st_ro = apply(rc.readout, out, ps.readout, st.readout) |
| 41 | + return out, merge(new_st, (readout = st_ro,)) |
| 42 | +end |
| 43 | + |
| 44 | +function collectstates(rc::AbstractReservoirComputer, data::AbstractMatrix, ps, st::NamedTuple) |
| 45 | + newst = st |
| 46 | + collected = Any[] |
| 47 | + for inp in eachcol(data) |
| 48 | + state_t, partial_st = _partial_apply(rc, inp, ps, newst) |
| 49 | + push!(collected, copy(state_t)) |
| 50 | + newst = merge(partial_st, (readout = newst.readout,)) |
| 51 | + end |
| 52 | + @assert !isempty(collected) |
| 53 | + states_raw = reduce(hcat, collected) |
| 54 | + states = eltype(data).(states_raw) |
| 55 | + return states, newst |
| 56 | +end |
| 57 | + |
| 58 | +_set_readout_weight(ps_readout::NamedTuple, wro) = merge(ps_readout, (; weight = wro)) |
| 59 | + |
| 60 | +function addreadout!(::AbstractReservoirComputer, output_matrix::AbstractMatrix, |
| 61 | + ps::NamedTuple, st::NamedTuple) |
| 62 | + @assert hasproperty(ps, :readout) |
| 63 | + new_readout = _set_readout_weight(ps.readout, output_matrix) |
| 64 | + return merge(ps, (readout = new_readout,)), st |
| 65 | +end |
| 66 | + |
| 67 | +@doc raw""" |
| 68 | + resetcarry!(rng, rc::ReservoirComputer, st; init_carry=nothing) |
| 69 | + resetcarry!(rng, rc::ReservoirComputer, ps, st; init_carry=nothing) |
| 70 | +
|
| 71 | +Reset (or set) the hidden-state carry of a model in the echo state network family. |
| 72 | +
|
| 73 | +If an existing carry is present in `st.cell.carry`, its leading dimension is used to |
| 74 | +infer the state size. Otherwise the reservoir output size is taken from |
| 75 | +`rc.reservoir.cell.out_dims`. When `init_carry=nothing`, the carry is cleared; the initializer |
| 76 | +from the struct construction will then be used. When a |
| 77 | +function is provided, it is called to create a new initial hidden state. |
| 78 | +
|
| 79 | +## Arguments |
| 80 | +
|
| 81 | +- `rng`: Random number generator (used if a new carry is sampled/created). |
| 82 | +- `rc`: A reservoir computing network model. |
| 83 | +- `st`: Current model states. |
| 84 | +- `ps`: Optional model parameters. Returned unchanged. |
| 85 | +
|
| 86 | +## Keyword arguments |
| 87 | +
|
| 88 | +- `init_carry`: Controls the initialization of the new carry. |
| 89 | + - `nothing` (default): remove/clear the carry (forces the cell to reinitialize |
| 90 | + from its own `init_state` on next use). |
| 91 | + - `f`: a function following standard from |
| 92 | + [WeightInitializers.jl](https://lux.csail.mit.edu/stable/api/Building_Blocks/WeightInitializers) |
| 93 | +
|
| 94 | +## Returns |
| 95 | +
|
| 96 | +- `resetcarry!(rng, rc, st; ...) -> st′`: |
| 97 | + Updated states with `st′.cell.carry` set to `nothing` or `(h0,)`. |
| 98 | +- `resetcarry!(rng, rc, ps, st; ...) -> (ps, st′)`: |
| 99 | + Same as above, but also returns the unchanged `ps` for convenience. |
| 100 | +
|
| 101 | +""" |
| 102 | +function resetcarry!(rng::AbstractRNG, rc::AbstractReservoirComputer, st; init_carry = nothing) |
| 103 | + carry = get(st.reservoir, :carry, nothing) |
| 104 | + if carry === nothing |
| 105 | + outd = rc.reservoir.cell.out_dims |
| 106 | + sz = outd |
| 107 | + else |
| 108 | + state = first(carry) |
| 109 | + sz = size(state, 1) |
| 110 | + end |
| 111 | + |
| 112 | + if init_carry === nothing |
| 113 | + new_state = nothing |
| 114 | + else |
| 115 | + new_state = init_carry(rng, sz, 1) |
| 116 | + new_state = (new_state,) |
| 117 | + end |
| 118 | + new_cell = merge(st.reservoir, (; carry = new_state)) |
| 119 | + return merge(st, (reservoir = new_cell,)) |
| 120 | +end |
| 121 | + |
| 122 | +function resetcarry!(rng::AbstractRNG, rc::AbstractReservoirComputer, |
| 123 | + ps, st; init_carry = nothing) |
| 124 | + return ps, resetcarry!(rng, rc, st; init_carry = init_carry) |
| 125 | +end |
0 commit comments