Skip to content

Commit 0333d44

Browse files
feat: add ReservoirComputer
1 parent 819f1e8 commit 0333d44

File tree

7 files changed

+144
-146
lines changed

7 files changed

+144
-146
lines changed

docs/src/tutorials/reca.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,8 @@ We are going to test the recall ability of the model, feeding the input data
4747
and investigating whether the predicted output equals the output data.
4848

4949
```@example reca
50-
_, st0 = setup(rng, reca) #reset the first ca state
50+
st0 = resetcarry!(rng, reca, st) #reset the first ca state
5151
pred_out, st = predict(reca, input, ps, st0)
5252
final_pred = convert(AbstractArray{Float32}, pred_out .> 0.5)
53-
5453
final_pred == output
5554
```

ext/RCCellularAutomataExt.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,10 @@
11
module RCCellularAutomataExt
2-
using ReservoirComputing: RECA, RandomMapping, RandomMaps, AbstractInputEncoding,
3-
IntegerType, LinearReadout, ReservoirChain, StatefulLayer
4-
import ReservoirComputing: RECACell, RECA
2+
using ReservoirComputing: RECA, AbstractInputEncoding, ReservoirComputer,
3+
IntegerType, LinearReadout, StatefulLayer
4+
import ReservoirComputing: RECACell, RECA, RandomMapping, RandomMaps
55
using CellularAutomata
66
using Random: randperm
77

8-
function RandomMapping(; permutations = 8, expansion_size = 40)
9-
RandomMapping(permutations, expansion_size)
10-
end
11-
12-
function RandomMapping(permutations; expansion_size = 40)
13-
RandomMapping(permutations, expansion_size)
14-
end
15-
168
function create_encoding(rm::RandomMapping, in_dims::IntegerType, generations::IntegerType)
179
maps = init_maps(in_dims, rm.permutations, rm.expansion_size)
1810
states_size = generations * rm.expansion_size * rm.permutations
@@ -99,7 +91,7 @@ function RECA(in_dims::IntegerType,
9991

10092
ro = LinearReadout(rm.states_size => out_dims, readout_activation)
10193

102-
return ReservoirChain((StatefulLayer(cell), mods..., ro)...)
94+
return ReservoirComputer(StatefulLayer(cell), mods, ro)
10395
end
10496

10597
end #module

src/ReservoirComputing.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ using WeightInitializers: DeviceAgnostic, PartialFunction, Utils
1818

1919
#@compat(public, (initialparameters)) #do I need to add intialstates/parameters in compat?
2020

21+
#reservoir computers
2122
include("generics.jl")
23+
include("reservoircomputer.jl")
2224
#layers
2325
include("layers/basic.jl")
2426
include("layers/lux_layers.jl")
@@ -39,6 +41,7 @@ include("models/esn_hybrid.jl")
3941
#extensions
4042
include("extensions/reca.jl")
4143

44+
export ReservoirComputer
4245
export ESNCell, StatefulLayer, LinearReadout, ReservoirChain, Collect, collectstates,
4346
train!,
4447
predict, resetcarry!

src/extensions/reca.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ struct RandomMapping{I, T} <: AbstractInputEncoding
5656
expansion_size::T
5757
end
5858

59+
function RandomMapping(; permutations = 8, expansion_size = 40)
60+
RandomMapping(permutations, expansion_size)
61+
end
62+
63+
function RandomMapping(permutations; expansion_size = 40)
64+
RandomMapping(permutations, expansion_size)
65+
end
66+
5967
struct RandomMaps{T, E, G, M, S} <: AbstractEncodingData
6068
permutations::T
6169
expansion_size::E

src/models/esn.jl

Lines changed: 3 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ Composition:
6666
6767
## Parameters
6868
69-
- `cell` — parameters of the internal [`ESNCell`](@ref), including:
69+
- `reservoir` — parameters of the internal [`ESNCell`](@ref), including:
7070
- `input_matrix :: (res_dims × in_dims)` — `W_in`
7171
- `reservoir_matrix :: (res_dims × res_dims)` — `W_res`
7272
- `bias :: (res_dims,)` — present only if `use_bias=true`
@@ -80,19 +80,11 @@ Composition:
8080
8181
## States
8282
83-
Created by `initialstates(rng, esn)`:
84-
85-
- `cell` — states for the internal [`ESNCell`](@ref) (e.g. `rng` used to sample initial hidden states).
83+
- `reservoir` — states for the internal [`ESNCell`](@ref) (e.g. `rng` used to sample initial hidden states).
8684
- `states_modifiers` — a `Tuple` with states for each modifier layer.
8785
- `readout` — states for [`LinearReadout`](@ref).
8886
8987
"""
90-
@concrete struct ESN <: AbstractEchoStateNetwork{(:cell, :states_modifiers, :readout)}
91-
cell::Any
92-
states_modifiers::Any
93-
readout::Any
94-
end
95-
9688
function ESN(in_dims::IntegerType, res_dims::IntegerType,
9789
out_dims::IntegerType, activation = tanh;
9890
readout_activation = identity,
@@ -103,33 +95,5 @@ function ESN(in_dims::IntegerType, res_dims::IntegerType,
10395
Tuple(state_modifiers) : (state_modifiers,)
10496
mods = _wrap_layers(mods_tuple)
10597
ro = LinearReadout(res_dims => out_dims, readout_activation)
106-
return ESN(cell, mods, ro)
107-
end
108-
109-
function initialparameters(rng::AbstractRNG, esn::ESN)
110-
ps_cell = initialparameters(rng, esn.cell)
111-
ps_mods = map(l -> initialparameters(rng, l), esn.states_modifiers) |> Tuple
112-
ps_ro = initialparameters(rng, esn.readout)
113-
return (cell = ps_cell, states_modifiers = ps_mods, readout = ps_ro)
114-
end
115-
116-
function initialstates(rng::AbstractRNG, esn::ESN)
117-
st_cell = initialstates(rng, esn.cell)
118-
st_mods = map(l -> initialstates(rng, l), esn.states_modifiers) |> Tuple
119-
st_ro = initialstates(rng, esn.readout)
120-
return (cell = st_cell, states_modifiers = st_mods, readout = st_ro)
121-
end
122-
123-
function _partial_apply(esn::ESN, inp, ps, st)
124-
out, st_cell = apply(esn.cell, inp, ps.cell, st.cell)
125-
out,
126-
st_mods = _apply_seq(
127-
esn.states_modifiers, out, ps.states_modifiers, st.states_modifiers)
128-
return out, (cell = st_cell, states_modifiers = st_mods)
129-
end
130-
131-
function (esn::ESN)(inp, ps, st)
132-
out, new_st = _partial_apply(esn, inp, ps, st)
133-
out, st_ro = apply(esn.readout, out, ps.readout, st.readout)
134-
return out, merge(new_st, (readout = st_ro,))
98+
return ReservoirComputer(cell, mods, ro)
13599
end

src/models/esn_generics.jl

Lines changed: 0 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,6 @@ abstract type AbstractEchoStateNetwork{Fields} <: AbstractReservoirComputer{Fiel
33
_wrap_layer(x) = x isa Function ? WrappedFunction(x) : x
44
_wrap_layers(xs::Tuple) = map(_wrap_layer, xs)
55

6-
@inline function _apply_seq(layers::Tuple, inp, ps::Tuple, st::Tuple)
7-
new_st_parts = Vector{Any}(undef, length(layers))
8-
for idx in eachindex(layers)
9-
inp, sti = apply(layers[idx], inp, ps[idx], st[idx])
10-
new_st_parts[idx] = sti
11-
end
12-
return inp, tuple(new_st_parts...)
13-
end
14-
156
@inline function _fillvec(x, n::Integer)
167
v = Vector{typeof(x)}(undef, n)
178
@inbounds @simd for i in 1:n
@@ -56,87 +47,3 @@ function _coerce_layer_mods(x)
5647
x isa AbstractVector ? Tuple(x) :
5748
(x,)
5849
end
59-
60-
_set_readout_weight(ps_readout::NamedTuple, wro) = merge(ps_readout, (; weight = wro))
61-
62-
function collectstates(esn::AbstractEchoStateNetwork, data::AbstractMatrix, ps, st::NamedTuple)
63-
newst = st
64-
collected = Any[]
65-
for inp in eachcol(data)
66-
state_t, partial_st = _partial_apply(esn, inp, ps, newst)
67-
push!(collected, copy(state_t))
68-
newst = merge(partial_st, (readout = newst.readout,))
69-
end
70-
states = eltype(data).(reduce(hcat, collected))
71-
@assert !isempty(collected)
72-
states_raw = reduce(hcat, collected)
73-
states = eltype(data).(states_raw)
74-
return states, newst
75-
end
76-
77-
function addreadout!(::AbstractEchoStateNetwork, output_matrix::AbstractMatrix,
78-
ps::NamedTuple, st::NamedTuple)
79-
@assert hasproperty(ps, :readout)
80-
new_readout = _set_readout_weight(ps.readout, output_matrix)
81-
return merge(ps, (readout = new_readout,)), st
82-
end
83-
84-
@doc raw"""
85-
resetcarry!(rng, esn::AbstractEchoStateNetwork, st; init_carry=nothing)
86-
resetcarry!(rng, esn::AbstractEchoStateNetwork, ps, st; init_carry=nothing)
87-
88-
Reset (or set) the hidden-state carry of a model in the echo state network family.
89-
90-
If an existing carry is present in `st.cell.carry`, its leading dimension is used to
91-
infer the state size. Otherwise the reservoir output size is taken from
92-
`esn.cell.cell.out_dims`. When `init_carry=nothing`, the carry is cleared; the initializer
93-
from the struct construction will then be used. When a
94-
function is provided, it is called to create a new initial hidden state.
95-
96-
## Arguments
97-
98-
- `rng`: Random number generator (used if a new carry is sampled/created).
99-
- `esn`: An echo state network model.
100-
- `st`: Current model states.
101-
- `ps`: Optional model parameters. Returned unchanged.
102-
103-
## Keyword arguments
104-
105-
- `init_carry`: Controls the initialization of the new carry.
106-
- `nothing` (default): remove/clear the carry (forces the cell to reinitialize
107-
from its own `init_state` on next use).
108-
- `f`: a function called as `f(rng, sz, batch)`, following standard from
109-
[WeightInitializers.jl](https://lux.csail.mit.edu/stable/api/Building_Blocks/WeightInitializers)
110-
111-
## Returns
112-
113-
- `resetcarry!(rng, esn, st; ...) -> st′`:
114-
Updated states with `st′.cell.carry` set to `nothing` or `(h0,)`.
115-
- `resetcarry!(rng, esn, ps, st; ...) -> (ps, st′)`:
116-
Same as above, but also returns the unchanged `ps` for convenience.
117-
118-
"""
119-
function resetcarry!(rng::AbstractRNG, esn::AbstractEchoStateNetwork, st; init_carry = nothing)
120-
carry = get(st.cell, :carry, nothing)
121-
if carry === nothing
122-
outd = esn.cell.cell.out_dims
123-
sz = outd
124-
else
125-
state = first(carry)
126-
sz = size(state, 1)
127-
end
128-
129-
if init_carry === nothing
130-
new_state = nothing
131-
else
132-
new_state = init_carry(rng, sz, 1)
133-
new_state = (new_state,)
134-
end
135-
new_cell = merge(st.cell, (; carry = new_state))
136-
return merge(st, (cell = new_cell,))
137-
end
138-
139-
function resetcarry!(rng::AbstractRNG, esn::AbstractEchoStateNetwork,
140-
ps, st; init_carry = nothing)
141-
return ps, resetcarry!(rng, esn, st; init_carry = init_carry)
142-
end

src/reservoircomputer.jl

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
@concrete struct ReservoirComputer <: AbstractReservoirComputer{(:reservoir, :states_modifiers, :readout)}
2+
reservoir::Any
3+
states_modifiers::Any
4+
readout::Any
5+
end
6+
7+
function initialparameters(rng::AbstractRNG, rc::ReservoirComputer)
8+
ps_res = initialparameters(rng, rc.reservoir)
9+
ps_mods = map(l -> initialparameters(rng, l), rc.states_modifiers) |> Tuple
10+
ps_ro = initialparameters(rng, rc.readout)
11+
return (reservoir = ps_res, states_modifiers = ps_mods, readout = ps_ro)
12+
end
13+
14+
function initialstates(rng::AbstractRNG, rc::ReservoirComputer)
15+
st_res = initialstates(rng, rc.reservoir)
16+
st_mods = map(l -> initialstates(rng, l), rc.states_modifiers) |> Tuple
17+
st_ro = initialstates(rng, rc.readout)
18+
return (reservoir = st_res, states_modifiers = st_mods, readout = st_ro)
19+
end
20+
21+
@inline function _apply_seq(layers::Tuple, inp, ps::Tuple, st::Tuple)
22+
new_st_parts = Vector{Any}(undef, length(layers))
23+
for idx in eachindex(layers)
24+
inp, sti = apply(layers[idx], inp, ps[idx], st[idx])
25+
new_st_parts[idx] = sti
26+
end
27+
return inp, tuple(new_st_parts...)
28+
end
29+
30+
function _partial_apply(rc::ReservoirComputer, inp, ps, st)
31+
out, st_res = apply(rc.reservoir, inp, ps.reservoir, st.reservoir)
32+
out,
33+
st_mods = _apply_seq(
34+
rc.states_modifiers, out, ps.states_modifiers, st.states_modifiers)
35+
return out, (reservoir = st_res, states_modifiers = st_mods)
36+
end
37+
38+
function (rc::AbstractReservoirComputer)(inp, ps, st)
39+
out, new_st = _partial_apply(rc, inp, ps, st)
40+
out, st_ro = apply(rc.readout, out, ps.readout, st.readout)
41+
return out, merge(new_st, (readout = st_ro,))
42+
end
43+
44+
function collectstates(rc::AbstractReservoirComputer, data::AbstractMatrix, ps, st::NamedTuple)
45+
newst = st
46+
collected = Any[]
47+
for inp in eachcol(data)
48+
state_t, partial_st = _partial_apply(rc, inp, ps, newst)
49+
push!(collected, copy(state_t))
50+
newst = merge(partial_st, (readout = newst.readout,))
51+
end
52+
@assert !isempty(collected)
53+
states_raw = reduce(hcat, collected)
54+
states = eltype(data).(states_raw)
55+
return states, newst
56+
end
57+
58+
_set_readout_weight(ps_readout::NamedTuple, wro) = merge(ps_readout, (; weight = wro))
59+
60+
function addreadout!(::AbstractReservoirComputer, output_matrix::AbstractMatrix,
61+
ps::NamedTuple, st::NamedTuple)
62+
@assert hasproperty(ps, :readout)
63+
new_readout = _set_readout_weight(ps.readout, output_matrix)
64+
return merge(ps, (readout = new_readout,)), st
65+
end
66+
67+
@doc raw"""
68+
resetcarry!(rng, rc::ReservoirComputer, st; init_carry=nothing)
69+
resetcarry!(rng, rc::ReservoirComputer, ps, st; init_carry=nothing)
70+
71+
Reset (or set) the hidden-state carry of a model in the echo state network family.
72+
73+
If an existing carry is present in `st.cell.carry`, its leading dimension is used to
74+
infer the state size. Otherwise the reservoir output size is taken from
75+
`rc.reservoir.cell.out_dims`. When `init_carry=nothing`, the carry is cleared; the initializer
76+
from the struct construction will then be used. When a
77+
function is provided, it is called to create a new initial hidden state.
78+
79+
## Arguments
80+
81+
- `rng`: Random number generator (used if a new carry is sampled/created).
82+
- `rc`: A reservoir computing network model.
83+
- `st`: Current model states.
84+
- `ps`: Optional model parameters. Returned unchanged.
85+
86+
## Keyword arguments
87+
88+
- `init_carry`: Controls the initialization of the new carry.
89+
- `nothing` (default): remove/clear the carry (forces the cell to reinitialize
90+
from its own `init_state` on next use).
91+
- `f`: a function following standard from
92+
[WeightInitializers.jl](https://lux.csail.mit.edu/stable/api/Building_Blocks/WeightInitializers)
93+
94+
## Returns
95+
96+
- `resetcarry!(rng, rc, st; ...) -> st′`:
97+
Updated states with `st′.cell.carry` set to `nothing` or `(h0,)`.
98+
- `resetcarry!(rng, rc, ps, st; ...) -> (ps, st′)`:
99+
Same as above, but also returns the unchanged `ps` for convenience.
100+
101+
"""
102+
function resetcarry!(rng::AbstractRNG, rc::AbstractReservoirComputer, st; init_carry = nothing)
103+
carry = get(st.reservoir, :carry, nothing)
104+
if carry === nothing
105+
outd = rc.reservoir.cell.out_dims
106+
sz = outd
107+
else
108+
state = first(carry)
109+
sz = size(state, 1)
110+
end
111+
112+
if init_carry === nothing
113+
new_state = nothing
114+
else
115+
new_state = init_carry(rng, sz, 1)
116+
new_state = (new_state,)
117+
end
118+
new_cell = merge(st.reservoir, (; carry = new_state))
119+
return merge(st, (reservoir = new_cell,))
120+
end
121+
122+
function resetcarry!(rng::AbstractRNG, rc::AbstractReservoirComputer,
123+
ps, st; init_carry = nothing)
124+
return ps, resetcarry!(rng, rc, st; init_carry = init_carry)
125+
end

0 commit comments

Comments
 (0)