Skip to content

Commit 788e066

Browse files
refactor: align RECA to new apis
1 parent cff8f0f commit 788e066

File tree

4 files changed

+207
-83
lines changed

4 files changed

+207
-83
lines changed

docs/src/reca_tutorials/reca.md

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,31 +18,37 @@ output = readdlm("./5bitoutput.txt", ',', Float64)
1818
To use a ReCA model, it is necessary to define the rule one intends to use. To do so, ReservoirComputing.jl leverages [CellularAutomata.jl](https://github.com/MartinuzziFrancesco/CellularAutomata.jl) that needs to be called as well to define the `RECA` struct:
1919

2020
```@example reca
21-
using ReservoirComputing, CellularAutomata
21+
using ReservoirComputing, CellularAutomata, Random
22+
Random.seed!(42)
23+
rng = MersenneTwister(17)
2224
2325
ca = DCA(90)
2426
```
2527

2628
To define the ReCA model, it suffices to call:
2729

2830
```@example reca
29-
reca = RECA(input, ca;
30-
generations=16,
31-
input_encoding=RandomMapping(16, 40))
31+
reca = RECA(4, 4, DCA(90);
32+
generations=16,
33+
input_encoding=RandomMapping(16, 40))
34+
ps, st = setup(rng, reca)
3235
```
33-
3436
After this, the training can be performed with the chosen method.
3537

3638
```@example reca
37-
output_layer = train(reca, output, StandardRidge(0.00001))
39+
ps, st = train!(reca, input, output, ps, st, StandardRidge(0.00001))
3840
```
3941

40-
The prediction in this case will be a `Predictive()` with the input data equal to the training data. In addition, to test the 5 bit memory task, a conversion from Float to Bool is necessary (at the moment, we are aware of a bug that doesn't allow boolean input data to the RECA models):
42+
We are going to test the recall ability of the model, feeding the input data
43+
and investigating wether the predicted output equals the output data.
44+
In addition, to test the 5 bit memory task, a conversion from Float to Bool
45+
is necessary (at the moment, we are aware of a bug that doesn't allow boolean
46+
input data to the RECA models):
4147

4248
```@example reca
43-
prediction = reca(Predictive(input), output_layer)
49+
_, st0 = setup(rng, reca) #reset the first ca state
50+
pred_out, st = predict(reca, input, ps, st0)
4451
final_pred = convert(AbstractArray{Float32}, prediction .> 0.5)
4552
4653
final_pred == output
4754
```
48-

ext/RCCellularAutomataExt.jl

Lines changed: 45 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,62 @@
11
module RCCellularAutomataExt
2-
using ReservoirComputing: RECA, RandomMapping, RandomMaps
3-
import ReservoirComputing: train, next_state_prediction!, AbstractOutputLayer, NLADefault,
4-
StandardStates, obtain_prediction
2+
using ReservoirComputing: RECA, RandomMapping, RandomMaps, AbstractInputEncoding,
3+
IntegerType, Readout, ReservoirChain, StatefulLayer
4+
import ReservoirComputing: RECACell, RECA
55
using CellularAutomata
66
using Random: randperm
77

8-
function RECA(train_data,
9-
automata;
10-
generations = 8,
11-
input_encoding = RandomMapping(),
12-
nla_type = NLADefault(),
13-
states_type = StandardStates())
14-
in_size = size(train_data, 1)
15-
#res_size = obtain_res_size(input_encoding, generations)
16-
state_encoding = create_encoding(input_encoding, train_data, generations)
17-
states = reca_create_states(state_encoding, automata, train_data)
18-
19-
return RECA(train_data, automata, state_encoding, nla_type, states, states_type)
8+
function (reca::RECACell)((inp, (ca_prev,)), ps, st::NamedTuple)
9+
rm = reca.enc
10+
T = eltype(inp)
11+
ca0 = T.(encoding(rm, inp, T.(ca_prev)))
12+
ca = CellularAutomaton(reca.automaton, ca0, rm.generations + 1)
13+
evo = ca.evolution
14+
feat2T = evo[2:end, :]
15+
feats = reshape(permutedims(feat2T), rm.states_size)
16+
ca_last = evo[end, :]
17+
return (T.(feats), (T.(ca_last),)), st
2018
end
2119

22-
#training dispatch
23-
function train(reca::RECA, target_data, training_method = StandardRidge; kwargs...)
24-
states_new = reca.states_type(reca.nla_type, reca.states, reca.train_data)
25-
return train(training_method, Float32.(states_new), Float32.(target_data); kwargs...)
20+
function (reca::RECACell)(inp::AbstractVector, ps, st::NamedTuple)
21+
ca = st.ca
22+
return reca((inp, (ca,)), ps, st)
2623
end
2724

28-
#predict dispatch
29-
function (reca::RECA)(prediction,
30-
output_layer::AbstractOutputLayer,
31-
initial_conditions = output_layer.last_value,
32-
last_state = zeros(reca.input_encoding.ca_size))
33-
return obtain_prediction(reca, prediction, last_state, output_layer;
34-
initial_conditions = initial_conditions)
35-
end
25+
function RECA(in_dims::IntegerType,
26+
out_dims::IntegerType,
27+
automaton;
28+
input_encoding::AbstractInputEncoding=RandomMapping(),
29+
generations::Integer=8,
30+
state_modifiers=(),
31+
readout_activation=identity)
32+
33+
rm = create_encoding(input_encoding, in_dims, generations)
34+
cell = RECACell(automaton, rm)
35+
36+
mods = state_modifiers isa Tuple || state_modifiers isa AbstractVector ?
37+
Tuple(state_modifiers) : (state_modifiers,)
3638

37-
function next_state_prediction!(reca::RECA, x, out, i, args...)
38-
rm = reca.input_encoding
39-
x = encoding(rm, out, x)
40-
ca = CellularAutomaton(reca.automata, x, rm.generations + 1)
41-
ca_states = ca.evolution[2:end, :]
42-
x_new = reshape(transpose(ca_states), rm.states_size)
43-
x = ca.evolution[end, :]
44-
return x, x_new
39+
ro = Readout(rm.states_size => out_dims, readout_activation)
40+
41+
return ReservoirChain((StatefulLayer(cell), mods..., ro)...)
4542
end
4643

47-
function RandomMapping(; permutations = 8, expansion_size = 40)
44+
function RandomMapping(; permutations=8, expansion_size=40)
4845
RandomMapping(permutations, expansion_size)
4946
end
5047

51-
function RandomMapping(permutations; expansion_size = 40)
48+
function RandomMapping(permutations; expansion_size=40)
5249
RandomMapping(permutations, expansion_size)
5350
end
5451

55-
function create_encoding(rm::RandomMapping, input_data, generations)
56-
maps = init_maps(size(input_data, 1), rm.permutations, rm.expansion_size)
52+
function create_encoding(rm::RandomMapping, in_dims::IntegerType, generations::IntegerType)
53+
maps = init_maps(in_dims, rm.permutations, rm.expansion_size)
5754
states_size = generations * rm.expansion_size * rm.permutations
5855
ca_size = rm.expansion_size * rm.permutations
59-
return RandomMaps(rm.permutations, rm.expansion_size, generations, maps, states_size,
60-
ca_size)
56+
return RandomMaps(rm.permutations, rm.expansion_size, generations, maps, states_size, ca_size)
6157
end
6258

59+
6360
function reca_create_states(rm::RandomMaps, automata, input_data)
6461
train_time = size(input_data, 2)
6562
states = zeros(rm.states_size, train_time)
@@ -82,21 +79,21 @@ function encoding(rm::RandomMaps, input_vector, tot_encoded_vector)
8279
new_tot_enc_vec = copy(tot_encoded_vector)
8380

8481
for i in 1:(rm.permutations)
85-
new_tot_enc_vec[((i - 1) * rm.expansion_size + 1):(i * rm.expansion_size)] = single_encoding(
82+
new_tot_enc_vec[((i-1)*rm.expansion_size+1):(i*rm.expansion_size)] = single_encoding(
8683
input_vector,
87-
new_tot_enc_vec[((i - 1) * rm.expansion_size + 1):(i * rm.expansion_size)],
84+
new_tot_enc_vec[((i-1)*rm.expansion_size+1):(i*rm.expansion_size)],
8885
rm.maps[i,
89-
:])
86+
:])
9087
end
9188

9289
return new_tot_enc_vec
9390
end
9491

95-
#function obtain_res_size(rm::RandomMapping, generations)
96-
# generations*rm.expansion_size*rm.permutations
97-
#end
98-
9992
function single_encoding(input_vector, encoded_vector, map)
93+
@assert length(map) == length(input_vector) """
94+
RandomMaps mismatch: map length = $(length(map)) but input length = $(length(input_vector)).
95+
(Build RandomMaps with in_dims = size(input, 1) used at training time.)
96+
"""
10097
new_enc_vec = copy(encoded_vector)
10198

10299
for i in 1:size(input_vector, 1)

src/ReservoirComputing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ export add_jumps!, backward_connection!, delay_line!, reverse_simple_cycle!,
5959
export train
6060
export ESN, HybridESN, KnowledgeModel, DeepESN
6161
#reca
62-
export RECA
62+
export RECACell, RECA
6363
export RandomMapping, RandomMaps
6464

6565
end #module

src/extensions/reca.jl

Lines changed: 146 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,50 @@ abstract type AbstractEncodingData end
66
RandomMapping(permutations; expansion_size=40)
77
RandomMapping(;permutations=8, expansion_size=40)
88
9-
Random mapping of the input data directly in the reservoir. The `expansion_size`
10-
determines the dimension of the single reservoir, and `permutations` determines the
11-
number of total reservoirs that will be connected, each with a different mapping.
12-
The detail of this implementation can be found in [1].
9+
Specify the **random input embedding** used by the Cellular Automata reservoir.
10+
Each time step, the input vector of length `in_dims` is randomly placed into a
11+
larger 1D lattice of length `expansion_size`, and this is repeated for
12+
`permutations` independent lattices (blocks). The concatenation of these blocks
13+
forms the CA initial condition of length: `ca_size = expansion_size * permutations`.
14+
The detail of this implementation can be found in [Nichele2017](@cite).
1315
14-
[1] Nichele, Stefano, and Andreas Molund. “Deep reservoir computing using cellular
15-
automata.” arXiv preprint arXiv:1703.02806 (2017).
16+
## Arguments
17+
18+
- `permutations`: number of independent random maps (blocks). Larger
19+
values increase feature diversity and `ca_size` proportionally.
20+
- `expansion_size`: width of each block (the size of a single CA
21+
lattice). Larger values increase the spatial resolution and both `ca_size`
22+
and `states_size`.
23+
24+
## Usage
25+
26+
This is a **configuration object**; it does not perform the mapping by itself.
27+
Create the concrete tables with `create_encoding` and pass them to
28+
[`RECACell`](@ref):
29+
30+
```julia
31+
using ReservoirComputing, CellularAutomata, Random
32+
33+
in_dims = 4
34+
generations = 8
35+
mapping = RandomMapping(permutations=8, expansion_size=40)
36+
37+
enc = ReservoirComputing.create_encoding(mapping, in_dims, generations) # → RandomMaps
38+
cell = RECACell(DCA(90), enc)
39+
40+
rc = ReservoirChain(
41+
StatefulLayer(cell),
42+
Readout(enc.states_size => in_dims; include_collect=true)
43+
)
44+
```
45+
46+
Or let [`RECA`](@ref) do this for you:
47+
48+
```julia
49+
rc = RECA(in_dims=4, out_dims=4, DCA(90);
50+
input_encoding = RandomMapping(permutations=8, expansion_size=40),
51+
generations = 8)
52+
```
1653
"""
1754
struct RandomMapping{I,T} <: AbstractInputEncoding
1855
permutations::I
@@ -28,28 +65,112 @@ struct RandomMaps{T,E,G,M,S} <: AbstractEncodingData
2865
ca_size::S
2966
end
3067

31-
abstract type AbstractReca end
68+
@doc raw"""
69+
RECACell(automaton, enc::RandomMaps)
70+
71+
Cellular Automata (CA)–based reservoir recurrent cell. At each time step,
72+
the input vector is randomly embedded into a CA configuration, the CA is
73+
evolved for a fixed number of generations, and the flattened CA evolution
74+
is emitted as the reservoir state. The last CA configuration is carried
75+
to the next step. For more details please refer to [Nichele2017](@cite),
76+
and [Yilmaz2014](@cite).
77+
78+
## Arguments
79+
80+
- `automaton`: A cellular automaton rule/object from `CellularAutomata.jl`
81+
(e.g., `DCA(90)`, `DCA(30)`, …).
82+
83+
- `enc`: Precomputed random-mapping/encoding metadata given as a
84+
[`RandomMapping`](@ref).
85+
86+
## Inputs
87+
88+
- Case A: a single input vector `x` with length
89+
`in_dims`. The cell internally uses the stored CA state (`st.ca`) as the
90+
previous configuration.
91+
92+
- Case B: a tuple `(x, (ca,))` where `x` is as above and
93+
`ca` has length `enc.ca_size`.
94+
95+
## Computation
96+
97+
1. Random embedding of `x` into a CA initial condition `c₀` using `enc.maps`
98+
across `enc.permutations` blocks of length `enc.expansion_size`.
99+
100+
2. CA evolution for `G = enc.generations` steps with the given `automaton`,
101+
producing an evolution matrix `E ∈ ℝ^{(G+1) × ca_size}` where `E[1,:] = c₀`
102+
and `E[t+1,:] = F(E[t,:])`.
103+
104+
3. Feature vector is the flattened stack of `E[2:end, :]` (dropping the
105+
initial row), shaped as a column vector of length `enc.states_size`.
106+
107+
4. Carry is the final CA configuration `E[end, :]`.
108+
109+
## Returns
110+
111+
- Output: `(h, (caₙ,))` where
112+
* `h` has length `enc.states_size` (the CA features),
113+
* `caₙ` has length `enc.ca_size` (next carry).
114+
- Updated (unchanged) cell state (parameters-free layer state).
115+
116+
## Parameters & State
117+
118+
- Parameters: none
119+
- State: `(ca = zeros(Float32, enc.ca_size))`
32120
33121
"""
34-
RECA(train_data,
35-
automata;
36-
generations = 8,
122+
@concrete struct RECACell <: AbstractReservoirRecurrentCell
123+
automaton
124+
enc <: RandomMaps
125+
end
126+
127+
Base.show(io::IO, reca::RECACell) = print(io,
128+
"RECACell(in ⇒ ", reca.enc.ca_size, ", out=", reca.enc.states_size,
129+
", gens=", reca.enc.generations, ", perms=", reca.enc.permutations,
130+
", exp=", reca.enc.expansion_size, ")")
131+
132+
initialparameters(::AbstractRNG, ::RECACell) = NamedTuple()
133+
134+
function initialstates(::AbstractRNG, reca::RECACell)
135+
return (ca=zeros(Float32, reca.enc.ca_size),)
136+
end
137+
138+
@doc raw"""
139+
RECA(in_dims, out_dims, automaton;
37140
input_encoding=RandomMapping(),
38-
nla_type = NLADefault(),
39-
states_type = StandardStates())
141+
generations=8, state_modifiers=(),
142+
readout_activation=identity)
143+
144+
Construct a cellular–automata reservoir model.
145+
146+
At each time step the input vector is randomly embedded into a Cellular
147+
Automaton (CA) lattice, the CA is evolved for `generations` steps, and the
148+
flattened evolution (excluding the initial row) is used as the reservoir state.
149+
A linear [`Readout`](@ref) maps these features to `out_dims`.
40150
41-
[1] Yilmaz, Ozgur. “_Reservoir computing using cellular automata._”
42-
arXiv preprint arXiv:1410.0162 (2014).
151+
!!! note
152+
This constructor is only available when the `CellularAutomata.jl` package is
153+
loaded.
43154
44-
[2] Nichele, Stefano, and Andreas Molund. “_Deep reservoir computing using cellular
45-
automata._” arXiv preprint arXiv:1703.02806 (2017).
155+
## Arguments
156+
157+
- `in_dims`: Number of input features (rows of training data).
158+
- `out_dims`: Number of output features (rows of target data).
159+
- `automaton`: A CA rule/object from `CellularAutomata.jl` (e.g. `DCA(90)`,
160+
`DCA(30)`, …).
161+
162+
## Keyword Arguments
163+
164+
- `input_encoding`: Random embedding spec with
165+
fields `permutations` and `expansion_size`.
166+
Default is `RandomMapping()`.
167+
- `generations`: Number of CA generations to evolve per time step.
168+
Default is 8.
169+
- `state_modifiers`: Optional tuple/vector of additional layers applied
170+
after the CA cell and before the readout (e.g., `NLAT2()`, `Pad(1.0)`,
171+
custom transforms, etc.). Functions are wrapped automatically.
172+
Default is none.
173+
- `readout_activation`: Activation applied by the readout
174+
Default is `identity`.
46175
"""
47-
struct RECA{S,R,E,N,T,Q} <: AbstractReca
48-
#res_size::I
49-
train_data::S
50-
automata::R
51-
input_encoding::E
52-
nla_type::N
53-
states::T
54-
states_type::Q
55-
end
176+
RECA(::Any...) = error("RECA requires CellularAutomata.jl; use it to enable this constructor.")

0 commit comments

Comments
 (0)