
Commit 11ef816

refac: remove old ESN and DeepESN, implement new APIs

Parent: 152f83c

10 files changed: +250 -849 lines changed

README.md

Lines changed: 9 additions & 8 deletions
@@ -105,19 +105,20 @@ possible parameters:
 ```julia
 input_size = 3
 res_size = 300
-esn = ESN(input_data, input_size, res_size;
-    reservoir=rand_sparse(; radius=1.2, sparsity=6 / res_size),
-    input_layer=weighted_init,
-    nla_type=NLAT2(),
-    rng=rng)
+esn = ReservoirChain(
+    StatefulLayer(ESNCell(input_size => res_size; init_reservoir=rand_sparse(; radius=1.2, sparsity=6/300))),
+    NLAT2(),
+    Readout(res_size => input_size)
+)
 ```
 
 The echo state network can now be trained and tested.
 If not specified, the training will always be ordinary least squares regression:
 
 ```julia
-output_layer = train(esn, target_data)
-output = esn(Generative(predict_len), output_layer)
+ps, st = setup(rng, esn)
+ps, st = train!(esn, input_data, target_data, ps, st)
+output, _ = predict(esn, 1250, ps, st; initialdata=test[:, 1])
 ```
 
 The data is returned as a matrix, `output` in the code above,
@@ -126,7 +127,7 @@ The results can now be easily plotted:
 
 ```julia
 using Plots
-plot(transpose(output); layout=(3, 1), label="predicted")
+plot(transpose(output); layout=(3, 1), label="predicted");
 plot!(transpose(test); layout=(3, 1), label="actual")
 ```
 
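For orientation, the pieces shown in this README diff can be assembled into one runnable script. This is a minimal sketch based only on the calls visible above; the random stand-in data and the helper variables (`data`, `train_len`, `predict_len`) are illustrative assumptions, not part of the commit.

```julia
using ReservoirComputing, Random

rng = Xoshiro(42)
data = rand(Float32, 3, 2000)      # stand-in for a real time series
train_len, predict_len = 700, 1250
input_data  = data[:, 1:train_len]
target_data = data[:, 2:(train_len + 1)]             # one-step-ahead targets
test        = data[:, (train_len + 1):(train_len + predict_len)]

input_size, res_size = 3, 300
esn = ReservoirChain(
    StatefulLayer(ESNCell(input_size => res_size;
        init_reservoir=rand_sparse(; radius=1.2, sparsity=6 / res_size))),
    NLAT2(),
    Readout(res_size => input_size))

ps, st = setup(rng, esn)                                # parameters and states
ps, st = train!(esn, input_data, target_data, ps, st)   # ridge/OLS readout fit
output, _ = predict(esn, predict_len, ps, st; initialdata=test[:, 1])
```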

src/ReservoirComputing.jl

Lines changed: 3 additions & 6 deletions
@@ -32,16 +32,15 @@ include("layers/esn_cell.jl")
 include("generics/states.jl")
 include("generics/predict.jl")
 include("generics/linear_regression.jl")
-#extensions
-include("extensions/reca.jl")
 #esn
 include("inits/inits_components.jl")
 include("inits/esn_inits.jl")
-include("layers/esn_reservoir_drivers.jl")
 include("models/esn.jl")
 include("models/deepesn.jl")
 include("models/hybridesn.jl")
-include("models/esn_predict.jl")
+#extensions
+include("extensions/reca.jl")
+
 
 
 
@@ -58,10 +57,8 @@ export block_diagonal, chaotic_init, cycle_jumps, delay_line, delay_line_backwar
     selfloop_forward_connection, simple_cycle, true_double_cycle
 export add_jumps!, backward_connection!, delay_line!, reverse_simple_cycle!,
     scale_radius!, self_loop!, simple_cycle!
-export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal
 export train
 export ESN, HybridESN, KnowledgeModel, DeepESN
-export Generative, Predictive, OutputLayer
 #reca
 export RECA
 export RandomMapping, RandomMaps

src/generics/linear_regression.jl

Lines changed: 49 additions & 21 deletions
@@ -29,18 +29,12 @@ function StandardRidge()
     return StandardRidge(0.0)
 end
 
-function train!(rc::ReservoirChain, train_data::AbstractArray,
-        target_data::AbstractArray, ps, st::NamedTuple, sr::StandardRidge=StandardRidge(0.0);
+function train!(rc::ReservoirChain, train_data, target_data, ps, st, sr=StandardRidge(0.0);
         return_states::Bool=false)
-    states = collectstates(rc, train_data, ps, st)
-    readout = train(sr, states, target_data)
-    ps, st = addreadout!(rc, readout, ps, st)
-
-    if return_states
-        return (ps, st), states
-    else
-        return ps, st
-    end
+    states, new_st = collectstates(rc, train_data, ps, st)
+    W = train(sr, states, target_data)
+    ps2, _ = addreadout!(rc, W, ps, new_st)
+    return return_states ? ((ps2, new_st), states) : (ps2, new_st)
 end
 
 function train(sr::StandardRidge, states::AbstractArray, target_data::AbstractArray)
@@ -53,16 +47,50 @@ function train(sr::StandardRidge, states::AbstractArray, target_data::AbstractAr
     return output_layer
 end
 
-function addreadout!(rc::ReservoirChain, readout_matrix::AbstractArray, ps, st::NamedTuple) #make sure the compile infers
-    ro_param = (; weight=readout_matrix)
-    new_ps = (;)
-    for ((name, layer), param) in zip(pairs(rc.layers), ps)
-        if layer isa Readout
-            param = merge(param, ro_param)
-        end
-        new_ps = merge(new_ps, (; name => param))
+_quote_keys(t) = Expr(:tuple, (QuoteNode(s) for s in t)...)
+
+@generated function _setweight_rt(p::NamedTuple{K}, W) where {K}
+    keys = K
+    Kq = _quote_keys(keys)
+    idx = findfirst(==(Symbol(:weight)), keys)
+
+    terms = Any[]
+    for i in 1:length(keys)
+        push!(terms, (idx === i) ? :(W) : :(getfield(p, $i)))
+    end
+
+    if idx === nothing
+        newK = _quote_keys((keys..., :weight))
+        return :(NamedTuple{$newK}(($(terms...), W)))
+    else
+        return :(NamedTuple{$Kq}(($(terms...),)))
     end
-    return new_ps, st
 end
 
-#use a recursion to make it more compiler safe
+@generated function _addreadout(layers::NamedTuple{K}, ps::NamedTuple{K}, W) where {K}
+    if length(K) == 0
+        return :(NamedTuple())
+    end
+    tailK = Base.tail(K)
+    Kq = _quote_keys(K)
+    tailKq = _quote_keys(tailK)
+
+    head_val = :((getfield(layers, 1) isa Readout) ?
+                 _setweight_rt(getfield(ps, 1), W) :
+                 getfield(ps, 1))
+
+    tail_call = :(_addreadout(NamedTuple{$tailKq}(Base.tail(layers)),
+        NamedTuple{$tailKq}(Base.tail(ps)),
+        W))
+
+    return :(NamedTuple{$Kq}(($head_val, Base.values($tail_call)...)))
+end
+
+function addreadout!(rc::ReservoirChain,
+        W::AbstractMatrix,
+        ps::NamedTuple,
+        st::NamedTuple)
+    @assert propertynames(rc.layers) == propertynames(ps)
+    new_ps = _addreadout(rc.layers, ps, W)
+    return new_ps, st
+end
src/generics/predict.jl

Lines changed: 0 additions & 120 deletions
@@ -1,123 +1,3 @@
-abstract type AbstractOutputLayer end
-abstract type AbstractPrediction end
-
-#general output layer struct
-struct OutputLayer{T,I,S,L} <: AbstractOutputLayer
-    training_method::T
-    output_matrix::I
-    out_size::S
-    last_value::L
-end
-
-function Base.show(io::IO, ol::OutputLayer)
-    print(io, "OutputLayer successfully trained with output size: ", ol.out_size)
-end
-
-#prediction types
-"""
-    Generative(prediction_len)
-
-A prediction strategy that enables models to generate autonomous multi-step
-forecasts by recursively feeding their own outputs back as inputs for
-subsequent prediction steps.
-
-# Parameters
-
-- `prediction_len`: The number of future steps to predict.
-
-# Description
-
-The `Generative` prediction method allows a model to perform multi-step
-forecasting by using its own previous predictions as inputs for future predictions.
-
-At each step, the model takes the current input, generates a prediction,
-and then incorporates that prediction into the input for the next step.
-This recursive process continues until the specified
-number of prediction steps (`prediction_len`) is reached.
-"""
-struct Generative{T} <: AbstractPrediction
-    prediction_len::T
-end
-
-struct Predictive{I,T} <: AbstractPrediction
-    prediction_data::I
-    prediction_len::T
-end
-
-"""
-    Predictive(prediction_data)
-
-A prediction strategy for supervised learning tasks,
-where a model predicts labels based on a provided set
-of input features (`prediction_data`).
-
-# Parameters
-
-- `prediction_data`: The input data used for prediction, `feature` x `sample`
-
-# Description
-
-The `Predictive` prediction method uses the provided input data
-(`prediction_data`) to produce corresponding labels or outputs based
-on the learned relationships in the model.
-"""
-function Predictive(prediction_data::AbstractArray)
-    prediction_len = size(prediction_data, 2)
-    return Predictive(prediction_data, prediction_len)
-end
-
-function obtain_prediction(rc::AbstractReservoirComputer, prediction::Generative,
-        x, output_layer::AbstractOutputLayer, args...;
-        initial_conditions=output_layer.last_value)
-    #x = last_state
-    prediction_len = prediction.prediction_len
-    train_method = output_layer.training_method
-    out_size = output_layer.out_size
-    output = output_storing(train_method, out_size, prediction_len, typeof(rc.states))
-    out = initial_conditions
-
-    for i in 1:prediction_len
-        x, x_new = next_state_prediction!(rc, x, out, i, args...)
-        out_tmp = get_prediction(train_method, output_layer, x_new)
-        out = store_results!(train_method, out_tmp, output, i)
-    end
-
-    return output
-end
-
-function obtain_prediction(rc::AbstractReservoirComputer, prediction::Predictive,
-        x, output_layer::AbstractOutputLayer, args...; kwargs...)
-    prediction_len = prediction.prediction_len
-    train_method = output_layer.training_method
-    out_size = output_layer.out_size
-    output = output_storing(train_method, out_size, prediction_len, typeof(rc.states))
-
-    for i in 1:prediction_len
-        y = @view prediction.prediction_data[:, i]
-        x, x_new = next_state_prediction!(rc, x, y, i, args...)
-        out_tmp = get_prediction(train_method, output_layer, x_new)
-        out = store_results!(output_layer.training_method, out_tmp, output, i)
-    end
-
-    return output
-end
-
-#linear models
-function get_prediction(training_method, output_layer::AbstractOutputLayer, x)
-    return output_layer.output_matrix * x
-end
-
-#single matrix for other training methods
-function output_storing(training_method, out_size, prediction_len, storing_type)
-    return adapt(storing_type, zeros(out_size, prediction_len))
-end
-
-#general storing -> single matrix
-function store_results!(training_method, out, output, i)
-    output[:, i] = out
-    return out
-end
-
 function predict(rc, steps::Int, ps, st; initialdata=nothing)
     if initialdata == nothing
         initialdata = rand(Float32, 3)
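The replacement `predict` is truncated in this view. As orientation for what replaced the deleted `Generative` loop, a closed-loop driver could look like the sketch below; the Lux-style `rc(x, ps, st)` call and the output allocation are assumptions for illustration, not the commit's actual implementation.

```julia
# Hedged sketch, not the commit's code: one possible generative predict loop.
# Assumes `rc(x, ps, st)` advances the chain one step (reservoir update,
# nonlinear transform, readout) and returns `(y, st)`, as Lux layers do.
function predict_sketch(rc, steps::Integer, ps, st; initialdata)
    output = zeros(eltype(initialdata), length(initialdata), steps)
    x = initialdata
    for i in 1:steps
        y, st = rc(x, ps, st)    # one forward pass through the chain
        output[:, i] = y
        x = y                    # feed the prediction back as the next input
    end
    return output, st
end
```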

src/layers/esn_cell.jl

Lines changed: 64 additions & 0 deletions
@@ -1,3 +1,67 @@
+@doc raw"""
+    ESNCell(in_dims => out_dims, [activation];
+        use_bias=false, init_bias=rand32,
+        init_reservoir=rand_sparse, init_input=weighted_init,
+        init_state=randn32, leak_coefficient=1.0)
+
+Echo State Network (ESN) recurrent cell with optional leaky integration.
+
+## Equations
+
+```math
+\begin{aligned}
+    \tilde{\mathbf{h}}(t) &= \phi\!\left(\mathbf{W}_{in}\,\mathbf{x}(t) +
+        \mathbf{W}_{res}\,\mathbf{h}(t-1) + \mathbf{b}\right) \\
+    \mathbf{h}(t) &= (1-\alpha)\,\mathbf{h}(t-1) + \alpha\,\tilde{\mathbf{h}}(t)
+\end{aligned}
+```
+## Arguments
+
+- `in_dims`: Input dimension.
+- `out_dims`: Reservoir (hidden state) dimension.
+- `activation`: Activation function. Default: `tanh`.
+
+## Keyword arguments
+
+- `use_bias`: Whether to include a bias term. Default: `false`.
+- `init_bias`: Initializer for the bias. Used only if `use_bias=true`.
+  Default is `rand32`.
+- `init_reservoir`: Initializer for the reservoir matrix `W_res`.
+  Default is [`rand_sparse`](@ref).
+- `init_input`: Initializer for the input matrix `W_in`.
+- `init_state`: Initializer for the hidden state when an external
+  state is not provided. Default is `randn32`.
+- `leak_coefficient`: Leak rate `α ∈ (0,1]`. Default: `1.0`.
+
+## Inputs
+
+- **Case 1:** `x :: AbstractArray (in_dims, batch)`
+  A fresh state is created via `init_state`; the call is forwarded to Case 2.
+- **Case 2:** `(x, (h,))` where `h :: AbstractArray (out_dims, batch)`
+  Computes the update and returns the new state.
+
+In both cases, the forward returns `((h_new, (h_new,)), st_out)` where `st_out`
+contains any updated internal state.
+
+## Returns
+
+- Output/hidden state `h_new :: out_dims` and state tuple `(h_new,)`.
+- Updated layer state (NamedTuple).
+
+## Parameters
+
+Created by `initialparameters(rng, esn)`:
+
+- `input_matrix :: (out_dims × in_dims)` — `W_in`
+- `reservoir_matrix :: (out_dims × out_dims)` — `W_res`
+- `bias :: (out_dims,)` — present only if `use_bias=true`
+
+## States
+
+Created by `initialstates(rng, esn)`:
+
+- `rng`: a replicated RNG used to sample initial hidden states when needed.
+"""
 @concrete struct ESNCell <: AbstractReservoirRecurrentCell
     activation
     in_dims <: IntegerType
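The docstring's update rule is easy to sanity-check in isolation. The standalone function below mirrors the two equations with plain arrays; it is an illustration of the math only, not the cell's actual Lux forward pass.

```julia
# Plain-array illustration of the ESNCell equations (not the real forward pass):
#   h̃(t) = ϕ(W_in x(t) + W_res h(t-1) + b)
#   h(t)  = (1 - α) h(t-1) + α h̃(t)
function esn_step(Win, Wres, b, h_prev, x; ϕ=tanh, α=1.0)
    h_tilde = ϕ.(Win * x .+ Wres * h_prev .+ b)
    return (1 - α) .* h_prev .+ α .* h_tilde
end

# With the default α = 1 this reduces to the classic ESN update
# h(t) = ϕ(W_in x + W_res h(t-1) + b); α < 1 adds leaky integration.
Win, Wres, b = randn(5, 3), 0.1 .* randn(5, 5), zeros(5)
h = esn_step(Win, Wres, b, zeros(5), rand(3); α=0.7)
```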
