Commit 2b99b00

docs: more work on docstrings
1 parent 8092f24 commit 2b99b00

File tree

4 files changed: +113 −28 lines


docs/src/api/train.md

Lines changed: 6 additions & 0 deletions
@@ -2,5 +2,11 @@
 
 ```@docs
 train!
+train
+```
+
+## Training methods
+
+```@docs
 StandardRidge
 ```

src/models/esn_generics.jl

Lines changed: 35 additions & 0 deletions
@@ -62,6 +62,41 @@ function addreadout!(::AbstractEchoStateNetwork, output_matrix::AbstractMatrix,
     return merge(ps, (readout=new_readout,)), st
 end
 
+@doc raw"""
+    resetcarry!(rng, esn::AbstractEchoStateNetwork, st; init_carry=nothing)
+    resetcarry!(rng, esn::AbstractEchoStateNetwork, ps, st; init_carry=nothing)
+
+Reset (or set) the hidden-state carry of a model in the echo state network family.
+
+If an existing carry is present in `st.cell.carry`, its leading dimension is used to
+infer the state size. Otherwise the reservoir output size is taken from
+`esn.cell.cell.out_dims`. When `init_carry=nothing`, the carry is cleared; the
+initializer from the struct construction will then be used. When a function is
+provided, it is called to create a new initial hidden state.
+
+## Arguments
+
+- `rng`: Random number generator (used if a new carry is sampled/created).
+- `esn`: An echo state network model.
+- `st`: Current model states.
+- `ps`: Optional model parameters. Returned unchanged.
+
+## Keyword arguments
+
+- `init_carry`: Controls the initialization of the new carry.
+  - `nothing` (default): remove/clear the carry (forces the cell to reinitialize
+    from its own `init_state` on next use).
+  - `f`: a function called as `f(rng, sz, batch)`, following the standard from
+    [WeightInitializers.jl](https://lux.csail.mit.edu/stable/api/Building_Blocks/WeightInitializers).
+
+## Returns
+
+- `resetcarry!(rng, esn, st; ...) -> st′`:
+  Updated states with `st′.cell.carry` set to `nothing` or `(h0,)`.
+- `resetcarry!(rng, esn, ps, st; ...) -> (ps, st′)`:
+  Same as above, but also returns the unchanged `ps` for convenience.
+"""
 function resetcarry!(rng::AbstractRNG, esn::AbstractEchoStateNetwork, st; init_carry=nothing)
     carry = get(st.cell, :carry, nothing)
     if carry === nothing
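A minimal usage sketch of the two `resetcarry!` methods added above. The `esn`, `ps`, and `st` bindings are assumed to come from the usual model setup and are placeholders here; the initializer lambda follows the `f(rng, sz, batch)` convention named in the docstring.

```julia
using Random

rng = Random.default_rng()

# Clear the carry: the cell falls back to its own `init_state`
# initializer the next time it runs.
st = resetcarry!(rng, esn, st)

# Set a fresh carry from an explicit initializer function, called
# as f(rng, sz, batch) in the WeightInitializers.jl style.
ps, st = resetcarry!(rng, esn, ps, st;
    init_carry=(rng, sz, batch) -> zeros(Float32, sz, batch))
```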

src/predict.jl

Lines changed: 12 additions & 8 deletions
@@ -9,22 +9,26 @@ sequence.
 ## 1) Auto-regressive rollout
 
 **Behavior**
+
 - Rolls the model forward for `steps` time steps.
 - At each step, the model’s output becomes the next input.
 
 ### Arguments
+
 - `rc`: The reservoir chain / model.
-- `steps::Integer`: Number of time steps to generate.
-- `ps`: Model parameters (from `setup` or after `train!`).
-- `st`: Model state (carry, RNG replicas, etc.), threaded across time.
+- `steps`: Number of time steps to generate.
+- `ps`: Model parameters.
+- `st`: Model states.
 
 ### Keyword Arguments
+
 - `initialdata=nothing`: Column vector used as the first input.
+  Has to be provided.
 
 ### Returns
 
-- `Y`: Generated outputs of shape `(out_dims, steps)`.
-- `st_out`: Final model state after `steps` steps.
+- `output`: Generated outputs of shape `(out_dims, steps)`.
+- `st`: Final model state after `steps` steps.
 
 
 ## 2) Teacher-forced / point-by-point
@@ -37,12 +41,12 @@ sequence.
 - `rc`: The reservoir chain / model.
 - `data`: Input sequence of shape `(in_dims, T)` (columns are time).
 - `ps`: Model parameters.
-- `st`: Initial model state before processing `data`.
+- `st`: Model states.
 
 ### Returns
 
-- `Y::AbstractMatrix`: Outputs for each input column, shape `(out_dims, T)`.
-- `st_out`: Final model state after consuming all `T` columns.
+- `output`: Outputs for each input column, shape `(out_dims, T)`.
+- `st`: Updated final model states.
 """
 function predict(rc::AbstractLuxLayer,
         steps::Integer, ps, st; initialdata::AbstractVector)
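A sketch of the two call patterns the docstring above describes. `rc`, `ps`, `st`, and `data` are placeholders for an already trained model and an `(in_dims, T)` input matrix; the teacher-forced call is assumed to take the data matrix in place of `steps`, per the argument lists above.

```julia
# 1) Auto-regressive rollout: generate 100 steps, seeding the loop
#    with the last observed input column.
output, st = predict(rc, 100, ps, st; initialdata=data[:, end])

# 2) Teacher-forced: one output column per input column.
output, st = predict(rc, data, ps, st)
```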

src/train.jl

Lines changed: 60 additions & 20 deletions
@@ -2,15 +2,17 @@
 
     StandardRidge([Type], [reg])
 
-Returns a training method for `train` based on ridge regression.
-The equations for ridge regression are as follows:
+Ridge regression method.
+
+## Equations
 
 ```math
 \mathbf{w} = (\mathbf{X}^\top \mathbf{X} +
     \lambda \mathbf{I})^{-1} \mathbf{X}^\top \mathbf{y}
 ```
 
-# Arguments
+## Arguments
+
 - `Type`: type of the regularization argument. Default is inferred internally,
   there's usually no need to tweak this
 - `reg`: regularization coefficient. Default is set to 0.0 (linear regression).
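As a sanity check on the ridge equation above: the closed-form solution is also the least-squares solution of an augmented system, which is the formulation `train(sr::StandardRidge, ...)` uses further down. A self-contained Julia sketch with synthetic data (states oriented features × time, so the products are transposed relative to the docstring's design matrix `X`):

```julia
using LinearAlgebra

# Synthetic sizes only; this checks that the augmented least-squares
# system matches the closed-form ridge formula.
X = randn(5, 40)    # reservoir states, (n_features, T)
y = randn(2, 40)    # targets, (n_outputs, T)
reg = 1e-2

# Closed form: w = (X Xᵀ + λI)⁻¹ X yᵀ
w_closed = (X * X' + reg * I) \ (X * y')

# Augmented system: least-squares solve of [Xᵀ; √λ I] w = [yᵀ; 0]
A = [X'; sqrt(reg) * I(5)]
b = [y'; zeros(5, 2)]
w_aug = A \ b

@assert w_closed ≈ w_aug
```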
@@ -40,52 +42,90 @@ end
 _set_readout(ps, m::ReservoirChain, W) = first(addreadout!(m, W, ps, NamedTuple()))
 
 @doc raw"""
-    train!(rc::ReservoirChain, train_data, target_data, ps, st,
-        sr::StandardRidge=StandardRidge(0.0);
-        washout::Int=0, return_states::Bool=false)
+    train!(rc, train_data, target_data, ps, st,
+        train_method=StandardRidge(0.0);
+        washout=0, return_states=false)
 
-Trains the Reservoir Computer by creating the reservoir states from `train_data`,
-and then fiting the last [`LinearReadout`](@ref) layer by (ridge)
-linear regression onto `target_data`. The learned weights are written into `ps`, and.
-The returned state is the final state after running through the full sequence.
+Trains a given reservoir computer by creating the reservoir states from `train_data`,
+and then fitting the readout layer using `target_data` as target.
+The learned weights/layer are written into `ps`.
 
 ## Arguments
 
-- `rc`: A [`ReservoirChain`](@ref) whose last trainable layer is a `LinearReadout`.
-- `train_data`: input sequence (columns are time steps).
+- `rc`: A reservoir computing model, either provided by ReservoirComputing.jl
+  or built with [`ReservoirChain`](@ref). Must contain a trainable layer
+  (for example [`LinearReadout`](@ref)) and a collection point [`Collect`](@ref).
+- `train_data`: input sequence where columns are time steps.
 - `target_data`: targets aligned with `train_data`.
-- `ps, st`: current parameters and state.
-- `sr`: ridge spec, e.g. `StandardRidge(1e-4)`; `0.0` gives ordinary least squares.
+- `ps`: model parameters.
+- `st`: model states.
+- `train_method`: training algorithm. Default is [`StandardRidge`](@ref).
 
 ## Keyword arguments
 
 - `washout`: number of initial time steps to discard (applied equally to features
-  and targets). Must satisfy `0 ≤ washout < T`. Default `0`.
+  and targets). Default `0`.
 - `return_states`: if `true`, also returns the feature matrix used
   for the fit.
+- `kwargs...`: additional keyword arguments for the training algorithm, if needed.
+  Defaults vary with the chosen training method.
 
 ## Returns
 
-- `(ps2, st_after)` — updated parameters and the final model state.
-- If `return_states=true`, also returns `states_used`.
+- `(ps, st)`: updated model parameters and states.
+- `((ps, st), states)`: if `return_states=true`.
 
 ## Notes
 
 - Features are produced by `collectstates(rc, train_data, ps, st)`. If you rely on
   the implicit collection of a [`LinearReadout`](@ref), make sure that readout was created with
-  `include_collect=true`, or insert an explicit [`Collect()`](@ref) earlier in the chain.
+  `include_collect=true`, or insert an explicit [`Collect()`](@ref) earlier in the
+  [`ReservoirChain`](@ref).
 """
 function train!(rc, train_data, target_data, ps, st,
         train_method=StandardRidge(0.0);
-        washout::Int=0, return_states::Bool=false)
+        washout::Int=0, return_states::Bool=false, kwargs...)
     states, st_after = collectstates(rc, train_data, ps, st)
     states_wo, traindata_wo = washout > 0 ? _apply_washout(states, target_data, washout) :
                               (states, target_data)
-    output_matrix = train(train_method, states_wo, traindata_wo)
+    output_matrix = train(train_method, states_wo, traindata_wo; kwargs...)
     ps2, st_after = addreadout!(rc, output_matrix, ps, st_after)
     return return_states ? ((ps2, st_after), states_wo) : (ps2, st_after)
 end
 
+@doc raw"""
+    train(train_method, states, target_data; kwargs...)
+
+Lower level training hook to fit a readout from precomputed
+reservoir features and given targets.
+
+Dispatching on this method with different training methods
+allows one to hook directly into [`train!`](@ref) without
+additional changes.
+
+## Arguments
+
+- `train_method`: An object describing the training algorithm and its hyperparameters
+  (e.g. regularization strength, solver choice, constraints).
+- `states`: Feature matrix with reservoir states (i.e. obtained with [`collectstates`](@ref)).
+  Shape `(n_features, T)`, where `T` is the number of samples (e.g. time steps).
+- `target_data`: Target matrix aligned with `states`. Shape `(n_outputs, T)`.
+
+## Returns
+
+- `output_weights`: Trained readout. Should be a forward method to be hooked into a
+  layer. For instance, in the case of linear regression `output_weights` is a matrix
+  consumable by [`LinearReadout`](@ref).
+
+## Notes
+
+- Any sequence pre-processing (e.g. washout) should be handled by the caller before
+  invoking `train`. See [`train!`](@ref) for an end-to-end workflow.
+- For very long `T`, consider chunked or iterative solvers to reduce memory usage.
+- If your approach returns additional artifacts (e.g. diagnostics), prefer storing
+  them inside `train_method` or exposing a separate API; keep `train`’s return
+  value as the forward method only.
+"""
 function train(sr::StandardRidge, states::AbstractArray, target_data::AbstractArray)
     n_states = size(states, 1)
     A = [states'; sqrt(sr.reg) * I(n_states)]
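Putting the pieces together, a hedged end-to-end sketch of the `train!`/`train` pair documented above. `rc`, `ps`, `st`, `train_data`, and `target_data` are placeholders from the usual model setup, and the ridge coefficient and washout length are illustrative:

```julia
# Fit the readout by ridge regression, discarding the first 50
# transient steps, and keep the collected feature matrix around.
(ps, st), states = train!(rc, train_data, target_data, ps, st,
    StandardRidge(1e-4); washout=50, return_states=true)

# The same fit through the lower-level hook, on precomputed states
# (targets must be trimmed to match the washed-out features).
output_weights = train(StandardRidge(1e-4), states, target_data[:, 51:end])
```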
