Commit 819f1e8

fix: make docs and tests pass
1 parent 0656d50 commit 819f1e8

17 files changed: +371 −49 lines changed

docs/pages.jl

Lines changed: 3 additions & 2 deletions

@@ -5,14 +5,15 @@ pages = [
         "Chaos forecasting with an ESN" => "tutorials/lorenz_basic.md",
         #"Using Different Training Methods" => "esn_tutorials/different_training.md",
         "Deep Echo State Networks" => "tutorials/deep_esn.md",
-        "Hybrid Echo State Networks" => "tutorials/hybrid.md",
+        #"Hybrid Echo State Networks" => "tutorials/hybrid.md",
         "Reservoir Computing with Cellular Automata" => "tutorials/reca.md"],
     "API Documentation" => Any[
         "Layers" => "api/layers.md",
         "Models" => "api/models.md",
-        "States" => "api/states.md",
+        "Utilities" => "api/utils.md",
         "Train" => "api/train.md",
         "Predict" => "api/predict.md",
+        "States" => "api/states.md",
         "Initializers" => "api/inits.md"],
     "References" => "references.md"
 ]

docs/src/api/layers.md

Lines changed: 10 additions & 4 deletions

@@ -2,21 +2,27 @@

 ## Base Layers

-```@doc
+```@docs
 ReservoirChain
 Collect
 StatefulLayer
-LinearReadout
 ```

-## External Layers
+## Readout Layers

 ```@docs
+LinearReadout
 SVMReadout
 ```

 ## Echo State Networks

-```@doc
+```@docs
 ESNCell
 ```
+
+## Reservoir Computing with Cellular Automata
+
+```@docs
+RECACell
+```

docs/src/api/models.md

Lines changed: 0 additions & 2 deletions

@@ -25,5 +25,3 @@ The input encodings are the equivalent of the input matrices of the ESNs. These
 ```@docs
 RandomMapping
 ```
-
-The training and prediction follow the same workflow as the ESN. It is important to note that currently we were unable to find any papers using these models with a `Generative` approach for the prediction, so full support is given only to the `Predictive` method.

docs/src/api/utils.md

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+# Utilities
+
+```@docs
+collectstates
+```
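
For context, a minimal usage sketch of the newly documented utility (the chain layout, sizes, and data below are illustrative assumptions, not part of this diff):

```julia
using ReservoirComputing, Random

# hypothetical chain: stateful ESN cell feeding a linear readout
rc = ReservoirChain(StatefulLayer(ESNCell(3 => 100)),
    LinearReadout(100 => 3; include_collect=true))
ps, st = setup(MersenneTwister(17), rc)

data = rand(Float32, 3, 50)       # (in_dims, T): one column per time step
states, st = collectstates(rc, data, ps, st)
size(states)                      # (n_features, T)
```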

docs/src/tutorials/lorenz_basic.md

Lines changed: 1 addition & 2 deletions

@@ -133,8 +133,7 @@ ReservoirComputing.jl provides
 additional utilities functions for autoregressive forecasting:

 ```@example lorenz
-pred_length
-output, st = predict(esn, predict_len, ps, st; initialdata=test[:, 1])
+output, st = predict(esn, predict_len, ps, st; initialdata=test_data[:, 1])
 ```

 To inspect the results, they can easily be plotted using an external library.
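
Since the tutorial defers plotting to an external library, here is a hedged sketch of one way to inspect the forecast (assuming `output` and `test_data` are `3 × predict_len` matrices; Plots.jl is just one option):

```julia
using Plots

# one subplot per Lorenz coordinate; solid = ground truth, dashed = forecast
p = plot(transpose(test_data); layout=(3, 1), label="actual")
plot!(p, transpose(output); layout=(3, 1), label="predicted", linestyle=:dash)
```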

docs/src/tutorials/scratch.md

Lines changed: 2 additions & 2 deletions

@@ -9,7 +9,7 @@ utilities)
 ## Using provided layers: ReservoirChain, ESNCell, and LinearReadout

 The library provides a [`ReservoirChain`](@ref), which is virtually
-equivalent to Lux's [`Chain`](@extref). Passing layers, or functions,
+equivalent to Lux's `Chain`. Passing layers, or functions,
 to the chain will concatenate them, and will allow the flow of the input
 data through the model.

@@ -48,7 +48,7 @@ the weights of the input layer:

 ```@example scratch
 using Random
-Random.seed(43)
+Random.seed!(43)

 rng = MersenneTwister(17)
 ps_s, st_s = setup(rng, esn_scratch)
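
Note the fix itself: `Random.seed(43)` does not exist; the seeding function in `Random` is `seed!`. For orientation, a sketch of the kind of chain the tutorial builds around this snippet (the `ESNCell` pair syntax and layer sizes are assumptions for illustration):

```julia
using ReservoirComputing, Random

Random.seed!(43)

# stateful ESN cell followed by a trainable linear readout
esn_scratch = ReservoirChain(
    StatefulLayer(ESNCell(3 => 300)),
    LinearReadout(300 => 3; include_collect=true)
)

rng = MersenneTwister(17)
ps_s, st_s = setup(rng, esn_scratch)
```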

src/extensions/reca.jl

Lines changed: 2 additions & 2 deletions

@@ -39,7 +39,7 @@ cell = RECACell(DCA(90), enc)

 rc = ReservoirChain(
     StatefulLayer(cell),
-    Readout(enc.states_size => in_dims; include_collect = true)
+    LinearReadout(enc.states_size => in_dims; include_collect = true)
 )
 ```

@@ -148,7 +148,7 @@ Construct a cellular–automata reservoir model.
 At each time step the input vector is randomly embedded into a Cellular
 Automaton (CA) lattice, the CA is evolved for `generations` steps, and the
 flattened evolution (excluding the initial row) is used as the reservoir state.
-A linear [`Readout`](@ref) maps these features to `out_dims`.
+A [`LinearReadout`](@ref) maps these features to `out_dims`.

 !!! note
     This constructor is only available when the `CellularAutomata.jl` package is
src/layers/basic.jl

Lines changed: 17 additions & 15 deletions

@@ -11,7 +11,7 @@ abstract type AbstractReservoirTrainableLayer <: AbstractLuxLayer end
 Linear readout layer with optional bias and elementwise activation. Intended as
 the final, trainable mapping from collected features (e.g., reservoir state) to
 outputs. When `include_collect=true`, training will collect features immediately
-before this layer (logically inserting a [`Collect()`](@ref) right before it).
+before this layer (logically inserting a [`Collect`](@ref) right before it).

 ## Equation

@@ -29,7 +29,7 @@ before this layer (logically inserting a [`Collect()`](@ref) right before it).

 - `use_bias`: Include an additive bias vector `b`. Default: `false`.
 - `include_collect`: If `true` (default), training collects features immediately
-  before this layer (as if a [`Collect()`](@ref) were inserted right before it).
+  before this layer (as if a [`Collect`](@ref) were inserted right before it).

 ## Parameters

@@ -42,9 +42,11 @@ before this layer (logically inserting a [`Collect()`](@ref) right before it).

 ## Notes

-- In ESN workflows, readout weights are typically set via ridge regression in
-  `train!(...)`. Therefore, how `Readout` gets initialized is of no consequence.
-- If you set `include_collect=false`, make sure a [`Collect()`](@ref) appears earlier in the chain.
+- In ESN workflows, readout weights are typically replaced via ridge regression in
+  [`train!`](@ref). Therefore, how `LinearReadout` gets initialized is of no consequence.
+  Additionally, the dimensions will also not be taken into account, as [`train!`](@ref)
+  will replace the weights.
+- If you set `include_collect=false`, make sure a [`Collect`](@ref) appears earlier in the chain.
   Otherwise training may operate on the post-readout signal,
   which is usually unintended.
 """
@@ -106,9 +108,9 @@ end
     Collect()

 Marker layer that passes data through unchanged but marks a feature
-checkpoint for [`collectstates`](@ref). At each time step, whenever a `Collect()` is
+checkpoint for [`collectstates`](@ref). At each time step, whenever a `Collect` is
 encountered in the chain, the current vector is recorded as part of the feature
-vector used to train the readout. If multiple `Collect()` layers exist, their
+vector used to train the readout. If multiple `Collect` layers exist, their
 vectors are concatenated with `vcat` in order of appearance.

 ## Arguments
@@ -137,13 +139,13 @@ vectors are concatenated with `vcat` in order of appearance.

 ## Notes

-- When used with a single `Collect()` before a [`LinearReadout`](@ref), training uses exactly
+- When used with a single `Collect` before a [`LinearReadout`](@ref), training uses exactly
   the tensor right before the readout (e.g., the reservoir state).
-- With **multiple** `Collect()` layers (e.g., after different submodules), the
+- With **multiple** `Collect` layers (e.g., after different submodules), the
   per-step features are `vcat`-ed in chain order to form one feature vector.
 - If the readout is constructed with `include_collect=true`, an *implicit*
   collection point is assumed immediately before the readout. Use an explicit
-  `Collect()` only when you want to control where/what is collected (or to stack
+  `Collect` only when you want to control where/what is collected (or to stack
   multiple features).

 ```julia
@@ -167,9 +169,9 @@ Base.show(io::IO, cl::Collect) = print(io, "Collection point of states")
     collectstates(rc, data, ps, st)

 Run the sequence `data` once through the reservoir chain `rc`, advancing the
-model state over time, and collect feature vectors at every [`Collect()`](@ref) layer.
-If more than one [`Collect()`](ref) is encountered in a step, their vectors are
-concatenated with `vcat` in order of appearance. If no [`Collect()`](@ref) is seen
+model state over time, and collect feature vectors at every [`Collect`](@ref) layer.
+If more than one [`Collect`](@ref) is encountered in a step, their vectors are
+concatenated with `vcat` in order of appearance. If no [`Collect`](@ref) is seen
 in a step, the feature defaults to the final vector exiting the chain for
 that time step.

@@ -180,15 +182,15 @@ that time step.

 ## Arguments

-- `rc`: A [`ReservoirChain`](@ref) (or compatible [`AbstractLuxLayer`](@extref) with `.layers`).
+- `rc`: A [`ReservoirChain`](@ref) (or compatible `AbstractLuxLayer` with `.layers`).
 - `data`: Input sequence of shape `(in_dims, T)`, where columns are time steps.
 - `ps`, `st`: Current parameters and state for `rc`.

 ## Returns

 - `states`: Reservoir states, i.e. a feature matrix with one column per
   time step. The feature dimension `n_features` equals the vertical concatenation
-  of all vectors captured at [`Collect()`](@ref) layers in that step.
+  of all vectors captured at [`Collect`](@ref) layers in that step.
 - `st`: Updated model states.

 """

src/layers/esn_cell.jl

Lines changed: 2 additions & 2 deletions

@@ -25,13 +25,13 @@ Echo State Network (ESN) recurrent cell with optional leaky integration.

 - `use_bias`: Whether to include a bias term. Default: `false`.
 - `init_bias`: Initializer for the bias. Used only if `use_bias=true`.
-  Default is [`rand32`](@extref).
+  Default is `rand32`.
 - `init_reservoir`: Initializer for the reservoir matrix `W_res`.
   Default is [`rand_sparse`](@ref).
 - `init_input`: Initializer for the input matrix `W_in`.
   Default is [`scaled_rand`](@ref).
 - `init_state`: Initializer for the hidden state when an external
-  state is not provided. Default is [`randn32`](@extref).
+  state is not provided. Default is `randn32`.
 - `leak_coefficient`: Leak rate `α ∈ (0,1]`. Default: `1.0`.

 ## Inputs
src/layers/lux_layers.jl

Lines changed: 84 additions & 7 deletions

@@ -8,6 +8,22 @@ end
 Base.show(io::IO, wf::WrappedFunction) = print(io, "WrappedFunction(", wf.func, ")")

 # adapted from lux layers/recurrent StatefulRecurrentCell
+@doc raw"""
+    StatefulLayer(cell::AbstractReservoirRecurrentCell)
+
+A lightweight wrapper that makes a recurrent cell carry its input state to the
+next step.
+
+## Arguments
+
+- `cell`: Any `AbstractReservoirRecurrentCell` (e.g. [`ESNCell`](@ref)).
+
+## States
+
+- `cell`: internal states for the wrapped `cell` (e.g., RNG replicas, etc.).
+- `carry`: the per-sequence hidden state; initialized to `nothing`.
+
+"""
 @concrete struct StatefulLayer <: AbstractLuxWrapperLayer{:cell}
     cell <: AbstractReservoirRecurrentCell
 end
@@ -29,15 +45,76 @@ function applyrecurrentcell(sl::AbstractReservoirRecurrentCell, inp, ps, st, ::N
     return apply(sl, inp, ps, st)
 end

-###build the ReservoirChain
+@doc raw"""
+    ReservoirChain(layers...; name=nothing)
+    ReservoirChain(xs::AbstractVector; name=nothing)
+    ReservoirChain(nt::NamedTuple; name=nothing)
+    ReservoirChain(; name=nothing, kwargs...)

-#abstract type RCLayer <: AbstractLuxLayer end
-#abstract type RCContainerLayer <: AbstractLuxContainerLayer end
+A lightweight, Lux-compatible container that composes a sequence of layers
+and executes them in order. The implementation of `ReservoirChain` is
+equivalent to Lux's own `Chain`.

-"""
-    ReservoirChain(layers...)
+## Construction
+
+You can build a chain from:
+
+- **Positional layers:** `ReservoirChain(l1, l2, ...)`
+- **A vector of layers:** `ReservoirChain([l1, l2, ...])`
+- **A named tuple of layers:** `ReservoirChain((; layer_a=l1, layer_b=l2))`
+- **Keywords (sugar for a named tuple):** `ReservoirChain(; layer_a=l1, layer_b=l2)`
+
+In all cases, function objects are automatically wrapped via `WrappedFunction`
+so they can participate like regular layers. If a [`LinearReadout`](@ref) with
+`include_collect=true` is present, the chain automatically inserts a [`Collect`](@ref)
+layer immediately before that readout.
+
+Use `name` to optionally tag the chain instance.
+
+## Inputs
+
+`(x, ps, st)` where:
+
+- `x`: input to the first layer.
+- `ps`: parameters as a named tuple with the same fields and order as the chain's layers.
+- `st`: states as a named tuple with the same fields and order as the chain's layers.
+
+The call `(c::ReservoirChain)(x, ps, st)` forwards `x` through each layer:
+`(x, ps_i, st_i) -> (x_next, st_i′)` and returns the final output and the
+updated states for every layer.
+
+## Returns
+
+- `(y, st′)` where `y` is the output of the last layer and `st′` is a named
+  tuple collecting the updated states for each layer.
+
+## Parameters
+
+- A `NamedTuple` whose fields correspond 1:1 with the layers. Each field
+  holds the parameters for that layer.
+- Field names are generated as `:layer_1, :layer_2, ...` when constructed
+  positionally, or preserved when you pass a `NamedTuple`/keyword constructor.
+
+## States
+
+- A `NamedTuple` whose fields correspond 1:1 with the layers. Each field
+  holds the state for that layer.
+
+## Layer access & indexing
+
+- `c[i]`: get the *i*-th layer (1-based).
+- `c[indices]`: return a new `ReservoirChain` formed by selecting a subset of layers.
+- `getproperty(c, :layer_k)`: access layer `k` by its generated/explicit name.
+- `length(c)`, `firstindex(c)`, `lastindex(c)`: standard collection interfaces.
+
+## Notes
+
+- **Function wrapping:** Any plain `Function` in the constructor is wrapped as
+  `WrappedFunction(f)`. Non-layer, non-function objects will error.
+- **Auto-collect for readouts:** When a [`LinearReadout`](@ref) has
+  `include_collect=true`, the constructor expands it to `(Collect(), readout)`
+  so that downstream tooling can capture features consistently.

-A simple container that holds a sequence of layers
 """
 @concrete struct ReservoirChain <: AbstractLuxWrapperLayer{:layers}
     layers <: NamedTuple
@@ -82,7 +159,7 @@ end
 (c::ReservoirChain)(x, ps, st::NamedTuple) = applychain(c.layers, x, ps, st)

 @generated function applychain(
-    layers::NamedTuple{fields}, x, ps, st::NamedTuple{fields}
+        layers::NamedTuple{fields}, x, ps, st::NamedTuple{fields}
 ) where {fields}
     @assert isa(fields, NTuple{<:Any, Symbol})
     N = length(fields)
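
A short sketch exercising the construction and indexing surface documented above (layer choices and sizes are illustrative):

```julia
using ReservoirComputing

# keyword construction preserves field names; the plain function is
# auto-wrapped in WrappedFunction
rc = ReservoirChain(; cell=StatefulLayer(ESNCell(3 => 100)),
    square=x -> x .^ 2,
    readout=LinearReadout(100 => 3))

rc[1]          # first layer, by position
rc.readout     # the readout, by name
length(rc)     # layer count (auto-collect may have inserted a Collect)
```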
