|
2 | 2 |
|
3 | 3 | StandardRidge([Type], [reg]) |
4 | 4 |
|
5 | | -Returns a training method for `train` based on ridge regression. |
6 | | -The equations for ridge regression are as follows: |
| 5 | +Ridge regression method. |
| 6 | +
|
| 7 | +## Equations |
7 | 8 |
|
8 | 9 | ```math |
9 | 10 | \mathbf{w} = (\mathbf{X}^\top \mathbf{X} + |
10 | 11 | \lambda \mathbf{I})^{-1} \mathbf{X}^\top \mathbf{y} |
11 | 12 | ``` |
12 | 13 |
|
13 | | -# Arguments |
| 14 | +## Arguments |
| 15 | +
|
14 | 16 | - `Type`: type of the regularization argument. Default is inferred internally;
15 | 17 |   there is usually no need to tweak this.
16 | 18 | - `reg`: regularization coefficient. Default is set to 0.0 (linear regression). |
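|  | +
|  | +## Example
|  | +
|  | +A minimal sketch of what this method computes; the random `states` and
|  | +`target` matrices below are placeholders, and the closed-form solve is the
|  | +equation above transposed to the `(n_outputs, n_features)` readout layout
|  | +(assuming the coefficient is stored in the `reg` field, as used by `train`):
|  | +
|  | +```julia
|  | +using LinearAlgebra
|  | +
|  | +sr = StandardRidge(1e-4)   # ridge regression with λ = 1e-4
|  | +
|  | +states = rand(30, 200)     # (n_features, T) collected reservoir states
|  | +target = rand(3, 200)      # (n_outputs, T) training targets
|  | +
|  | +# W_out = Y Xᵀ (X Xᵀ + λI)⁻¹, i.e. the ridge solution from the equation above
|  | +W_out = (target * states') / (states * states' + sr.reg * I)
|  | +```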
|
40 | 42 | _set_readout(ps, m::ReservoirChain, W) = first(addreadout!(m, W, ps, NamedTuple())) |
41 | 43 |
|
42 | 44 | @doc raw""" |
43 | | - train!(rc::ReservoirChain, train_data, target_data, ps, st, |
44 | | - sr::StandardRidge=StandardRidge(0.0); |
45 | | - washout::Int=0, return_states::Bool=false) |
| 45 | + train!(rc, train_data, target_data, ps, st, |
| 46 | + train_method=StandardRidge(0.0); |
| 47 | + washout=0, return_states=false) |
46 | 48 |
|
47 | | -Trains the Reservoir Computer by creating the reservoir states from `train_data`, |
48 | | -and then fiting the last [`LinearReadout`](@ref) layer by (ridge) |
49 | | -linear regression onto `target_data`. The learned weights are written into `ps`, and. |
50 | | -The returned state is the final state after running through the full sequence. |
| 49 | +Trains a given reservoir computer by creating the reservoir states from `train_data`,
| 50 | +and then fitting the readout layer onto `target_data`.
| 51 | +The learned weights/layer are written into `ps`. |
51 | 52 |
|
52 | 53 | ## Arguments |
53 | 54 |
|
54 | | -- `rc`: A [`ReservoirChain`](@ref) whose last trainable layer is a `LinearReadout`. |
55 | | -- `train_data`: input sequence (columns are time steps). |
| 55 | +- `rc`: A reservoir computing model, either provided by ReservoirComputing.jl |
| 56 | +  or built with [`ReservoirChain`](@ref). It must contain a trainable layer
| 57 | +  (for example [`LinearReadout`](@ref)) and a collection point ([`Collect`](@ref)).
| 58 | +- `train_data`: input sequence where columns are time steps. |
56 | 59 | - `target_data`: targets aligned with `train_data`. |
57 | | -- `ps, st`: current parameters and state. |
58 | | -- `sr`: ridge spec, e.g. `StandardRidge(1e-4)`; `0.0` gives ordinary least squares. |
| 60 | +- `ps`: model parameters. |
| 61 | +- `st`: model states. |
| 62 | +- `train_method`: training algorithm. Default is [`StandardRidge`](@ref). |
59 | 63 |
|
60 | 64 | ## Keyword arguments |
61 | 65 |
|
62 | 66 | - `washout`: number of initial time steps to discard (applied equally to features |
63 | | - and targets). Must satisfy `0 ≤ washout < T`. Default `0`. |
| 67 | + and targets). Default `0`. |
64 | 68 | - `return_states`: if `true`, also returns the feature matrix used |
65 | 69 | for the fit. |
| 70 | +- `kwargs...`: additional keyword arguments for the training algorithm, if needed. |
| 71 | +  Defaults vary across training methods.
66 | 72 |
|
67 | 73 | ## Returns |
68 | 74 |
|
69 | | -- `(ps2, st_after)` — updated parameters and the final model state. |
70 | | -- If `return_states=true`, also returns `states_used`. |
| 75 | +- `(ps, st)`: updated model parameters and states. |
| 76 | +- `((ps, st), states)`: if `return_states=true`, the states used for the fit
| 77 | +  are returned as well.
71 | 77 |
|
72 | 78 | ## Notes |
73 | 79 |
|
74 | 80 | - Features are produced by `collectstates(rc, train_data, ps, st)`. If you rely on |
75 | 81 | the implicit collection of a [`LinearReadout`](@ref), make sure that readout was created with |
76 | | - `include_collect=true`, or insert an explicit [`Collect()`](@ref) earlier in the chain. |
| 82 | + `include_collect=true`, or insert an explicit [`Collect()`](@ref) earlier in the |
| 83 | + [`ReservoirChain`](@ref). |
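|  | +
|  | +## Example
|  | +
|  | +A minimal end-to-end sketch. Building `rc`, `ps`, and `st` is assumed to have
|  | +happened beforehand (e.g. via the usual Lux-style setup); the data shapes are
|  | +placeholders:
|  | +
|  | +```julia
|  | +# rc: a ReservoirChain ending in a LinearReadout (with a Collect point)
|  | +# ps, st: its parameters and states
|  | +train_data = rand(Float32, 3, 1000)    # (n_inputs, T), columns are time steps
|  | +target_data = rand(Float32, 3, 1000)   # (n_outputs, T), aligned with the inputs
|  | +
|  | +ps, st = train!(rc, train_data, target_data, ps, st,
|  | +    StandardRidge(1e-6); washout=100)
|  | +```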
77 | 84 | """ |
78 | 85 | function train!(rc, train_data, target_data, ps, st, |
79 | 86 | train_method=StandardRidge(0.0); |
80 | | - washout::Int=0, return_states::Bool=false) |
| 87 | + washout::Int=0, return_states::Bool=false, kwargs...) |
81 | 88 | states, st_after = collectstates(rc, train_data, ps, st) |
82 | 89 | states_wo, traindata_wo = washout > 0 ? _apply_washout(states, target_data, washout) : |
83 | 90 | (states, target_data) |
84 | | - output_matrix = train(train_method, states_wo, traindata_wo) |
| 91 | + output_matrix = train(train_method, states_wo, traindata_wo; kwargs...) |
85 | 92 | ps2, st_after = addreadout!(rc, output_matrix, ps, st_after) |
86 | 93 | return return_states ? ((ps2, st_after), states_wo) : (ps2, st_after) |
87 | 94 | end |
88 | 95 |
|
| 96 | +@doc raw""" |
| 97 | + train(train_method, states, target_data; kwargs...) |
| 98 | +
|
| 99 | +Lower-level training hook to fit a readout from precomputed
| 100 | +reservoir features and given targets. |
| 101 | +
|
| 102 | +Adding a method of this function for a custom training type
| 103 | +hooks it directly into [`train!`](@ref) without
| 104 | +additional changes.
| 105 | +
|
| 106 | +## Arguments |
| 107 | +
|
| 108 | +- `train_method`: An object describing the training algorithm and its hyperparameters |
| 109 | + (e.g. regularization strength, solver choice, constraints). |
| 110 | +- `states`: Feature matrix of reservoir states (e.g. obtained with [`collectstates`](@ref)).
| 111 | + Shape `(n_features, T)`, where `T` is the number of samples (e.g. time steps). |
| 112 | +- `target_data`: Target matrix aligned with `states`. Shape `(n_outputs, T)`. |
| 113 | +
|
| 114 | +## Returns |
| 115 | +
|
| 116 | +- `output_weights`: Trained readout. Should be a forward method to be hooked into a |
| 117 | +  layer. For instance, in the case of linear regression `output_weights` is a matrix
| 118 | + consumable by [`LinearReadout`](@ref). |
| 119 | +
|
| 120 | +## Notes |
| 121 | +
|
| 122 | +- Any sequence pre-processing (e.g. washout) should be handled by the caller before |
| 123 | + invoking `train`. See [`train!`](@ref) for an end-to-end workflow. |
| 124 | +- For very long `T`, consider chunked or iterative solvers to reduce memory usage. |
| 125 | +- If your approach returns additional artifacts (e.g. diagnostics), prefer storing |
| 126 | + them inside `train_method` or exposing a separate API; keep `train`’s return |
| 127 | + value as the forward method only. |
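|  | +
|  | +## Example
|  | +
|  | +A sketch of hooking in a custom method; the `PseudoInverse` type below is
|  | +hypothetical, not part of the package:
|  | +
|  | +```julia
|  | +using LinearAlgebra
|  | +
|  | +struct PseudoInverse end   # hypothetical training method with no hyperparameters
|  | +
|  | +# fit the readout via the Moore-Penrose pseudoinverse instead of ridge
|  | +function train(::PseudoInverse, states::AbstractArray, target_data::AbstractArray)
|  | +    return target_data * pinv(states)   # (n_outputs, n_features) readout matrix
|  | +end
|  | +
|  | +# train!(rc, train_data, target_data, ps, st, PseudoInverse()) now just works
|  | +```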
| 128 | +""" |
89 | 129 | function train(sr::StandardRidge, states::AbstractArray, target_data::AbstractArray) |
90 | 130 | n_states = size(states, 1) |
91 | 131 | A = [states'; sqrt(sr.reg) * I(n_states)] |
|