Skip to content

Commit eca86f2

Browse files
feat: first pass of HybridESN
1 parent 11d0843 commit eca86f2

File tree

3 files changed

+79
-119
lines changed

3 files changed

+79
-119
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ esn = ReservoirChain(
123123

124124
### 3. Train the Echo State Network
125125

126-
ReservoirCOmputing.jl builds on Lux(Core), so in order to train the model
126+
ReservoirComputing.jl builds on Lux(Core), so in order to train the model
127127
we first need to instantiate the parameters and the states:
128128

129129
```julia

src/models/deepesn.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ function DeepESN(in_dims::Int,
1616
res_dims::AbstractVector{<:Int},
1717
out_dims,
1818
activation=tanh;
19-
activations=nothing,
20-
leaks=1.0,
19+
leak_coefficient=1.0,
2120
init_reservoir=rand_sparse,
2221
init_input=weighted_init,
2322
init_bias=zeros32,
@@ -28,8 +27,8 @@ function DeepESN(in_dims::Int,
2827

2928
num_reservoirs = length(res_dims)
3029

31-
acts = activations === nothing ? _asvec(activation, num_reservoirs) : _asvec(activations, num_reservoirs)
32-
leaksv = _asvec(leaks, num_reservoirs)
30+
acts = _asvec(activation, num_reservoirs)
31+
leaksv = _asvec(leak_coefficient, num_reservoirs)
3332
inres = _asvec(init_reservoir, num_reservoirs)
3433
ininp = _asvec(init_input, num_reservoirs)
3534
inbias = _asvec(init_bias, num_reservoirs)

src/models/hybridesn.jl

Lines changed: 75 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,6 @@
1-
struct HybridESN{I,S,V,N,T,O,M,B,ST,W,IS}
2-
res_size::I
3-
train_data::S
4-
model::V
5-
nla_type::N
6-
input_matrix::T
7-
reservoir_driver::O
8-
reservoir_matrix::M
9-
bias_vector::B
10-
states_type::ST
11-
washout::W
12-
states::IS
13-
end
1+
############################
2+
# Knowledge-model wrapper #
3+
############################
144

155
struct KnowledgeModel{T,K,O,I,S,D}
166
prior_model::T
@@ -22,124 +12,95 @@ struct KnowledgeModel{T,K,O,I,S,D}
2212
end
2313

2414
"""
25-
KnowledgeModel(prior_model, u0, tspan, datasize)
26-
27-
Constructs a `Hybrid` variation of Echo State Networks (ESNs) [^Pathak2018]
28-
integrating a knowledge-based model (`prior_model`) with ESNs.
29-
30-
# Parameters
15+
KnowledgeModel(prior_model, u0, tspan, datasize)
3116
32-
- `prior_model`: A knowledge-based model function for integration with ESNs.
33-
- `u0`: Initial conditions for the model.
34-
- `tspan`: Time span as a tuple, indicating the duration for model operation.
35-
- `datasize`: The size of the data to be processed.
36-
37-
[^Pathak2018]: Jaideep Pathak et al.
38-
"Hybrid Forecasting of Chaotic Processes:
39-
Using Machine Learning in Conjunction with a Knowledge-Based Model" (2018).
17+
Build a `KnowledgeModel` and precompute `model_data` on a time grid of length
18+
`datasize+1`. The extra step aligns with teacher-forced (xₜ → yₜ₊₁) usage.
4019
"""
4120
function KnowledgeModel(prior_model, u0, tspan, datasize)
4221
trange = collect(range(tspan[1], tspan[2]; length=datasize))
22+
@assert length(trange) >= 2 "datasize must be ≥ 2 to infer dt"
4323
dt = trange[2] - trange[1]
44-
tsteps = push!(trange, dt + trange[end])
45-
tspan_new = (tspan[1], dt + tspan[2])
46-
model_data = prior_model(u0, tspan_new, tsteps)
47-
return KnowledgeModel(prior_model, u0, tspan, dt, datasize, model_data)
24+
tsteps = push!(trange, trange[end] + dt)
25+
tspan2 = (tspan[1], tspan[2] + dt)
26+
mdl = prior_model(u0, tspan2, tsteps)
27+
return KnowledgeModel(prior_model, u0, tspan, dt, datasize, mdl)
4828
end
4929

50-
"""
51-
HybridESN(model, train_data, in_size, res_size; kwargs...)
52-
53-
Construct a Hybrid Echo State Network (ESN) model that integrates
54-
traditional Echo State Networks with a predefined knowledge model [^Pathak2018].
55-
56-
# Parameters
57-
58-
- `model`: A `KnowledgeModel` instance representing the knowledge-based model
59-
to be integrated with the ESN.
60-
- `train_data`: The training dataset used for the ESN. This data can be
61-
preprocessed or raw data depending on the nature of the problem and the
62-
preprocessing steps considered.
63-
- `in_size`: The size of the input layer, i.e., the number of input units
64-
to the ESN.
65-
- `res_size`: The size of the reservoir, i.e., the number of neurons in
66-
the hidden layer of the ESN.
67-
68-
# Optional Keyword Arguments
69-
70-
- `input_layer`: A function to initialize the input matrix.
71-
Default is `scaled_rand`.
72-
- `reservoir`: A function to initialize the reservoir matrix.
73-
Default is `rand_sparse`.
74-
- `bias`: A function to initialize the bias vector.
75-
Default is `zeros32`.
76-
- `reservoir_driver`: The driving system for the reservoir.
77-
Default is an RNN model.
78-
- `nla_type`: The type of non-linear activation used in the reservoir.
79-
Default is `NLADefault()`.
80-
- `states_type`: Defines the type of states used in the
81-
ESN. Default is `StandardStates()`.
82-
- `washout`: The number of initial timesteps to be
83-
discarded in the ESN's training phase. Default is 0.
84-
- `rng`: Random number generator used for initializing weights.
85-
Default is `Utils.default_rng()`.
86-
- `T`: The data type for the matrices (e.g., `Float32`).
87-
- `matrix_type`: The type of matrix used for storing the training data.
88-
Default is inferred from `train_data`.
30+
# Helper: forecast a KB stream for `steps` auto-regressive steps beyond tspan
31+
function kb_forecast(km::KnowledgeModel, steps::Integer)
32+
@assert steps >= 1
33+
t0 = km.tspan[2] + km.dt
34+
tgrid = collect(t0:km.dt:(t0+km.dt*(steps-1)))
35+
tspan = (t0, tgrid[end])
36+
u0 = km.model_data[:, end]
37+
mdl = km.prior_model(u0, tspan, [t0; tgrid[2:end]])
38+
return mdl
39+
end
8940

90-
[^Pathak2018]: Jaideep Pathak et al.
91-
"Hybrid Forecasting of Chaotic Processes:
92-
Using Machine Learning in Conjunction with a Knowledge-Based Model" (2018).
93-
"""
94-
function HybridESN(model::KnowledgeModel, train_data::AbstractArray,
95-
in_size::Int, res_size::Int; input_layer=scaled_rand, reservoir=rand_sparse,
96-
bias=zeros32, reservoir_driver=RNN(),
97-
nla_type=NLADefault(),
98-
states_type=StandardStates(), washout::Int=0,
99-
rng::AbstractRNG=Utils.default_rng(), T=Float32,
100-
matrix_type=typeof(train_data))
101-
train_data = vcat(train_data, model.model_data[:, 1:(end-1)])
41+
kb_stream_train(km::KnowledgeModel, T::Integer) = km.model_data[:, 1:T]
10242

103-
in_size = size(train_data, 1)
10443

44+
# Concats a column from `stream` at each step: z_t = vcat(x_t, stream[:, i])
45+
@concrete struct AttachStream <: AbstractLuxLayer
46+
stream <: AbstractMatrix
47+
end
10548

106-
reservoir_matrix = reservoir(rng, T, res_size, res_size)
107-
#different from ESN, why?
108-
input_matrix = input_layer(rng, T, res_size, in_size)
109-
bias_vector = bias(rng, res_size)
110-
inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size)
111-
states = create_states(inner_res_driver, train_data, washout, reservoir_matrix,
112-
input_matrix, bias_vector)
113-
train_data = train_data[:, (washout+1):end]
49+
initialparameters(::AbstractRNG, ::AttachStream) = NamedTuple()
50+
initialstates(::AbstractRNG, ::AttachStream) = (i=1,)
11451

115-
return HybridESN(res_size, train_data, model, nla_type, input_matrix,
116-
inner_res_driver, reservoir_matrix, bias_vector, states_type, washout,
117-
states)
52+
function (l::AttachStream)(x::AbstractVector, ps, st::NamedTuple)
53+
@boundscheck (st.i <= size(l.stream, 2)) ||
54+
throw(BoundsError(l.stream, st.i))
55+
out = vcat(x, @view l.stream[:, st.i])
56+
return out, (i=st.i + 1,)
11857
end
11958

120-
function (hesn::HybridESN)(prediction,
121-
output_layer, last_state::AbstractArray=hesn.states[
122-
:, [end]],
59+
"""
60+
HybridESN(km::KnowledgeModel,
61+
in_dims::Integer, res_dims::Integer, out_dims::Integer,
62+
activation=tanh;
63+
state_modifiers=(),
64+
readout_activation=identity,
65+
include_collect=true,
66+
kwargs...)
67+
68+
Build a hybrid ESN as a `ReservoirChain`:
69+
`StatefulLayer(ESNCell) → modifiers → AttachStream(train KB) → Readout`.
70+
"""
71+
function HybridESN(km::KnowledgeModel,
72+
in_dims::Integer, res_dims::Integer, out_dims::Integer,
73+
activation=tanh;
74+
state_modifiers=(),
75+
readout_activation=identity,
76+
include_collect::Bool=true,
12377
kwargs...)
124-
km = hesn.model
125-
pred_len = prediction.prediction_len
78+
cell = ESNCell(in_dims => res_dims, activation; kwargs...)
12679

127-
model = km.prior_model
128-
predict_tsteps = [km.tspan[2] + km.dt]
129-
[append!(predict_tsteps, predict_tsteps[end] + km.dt) for i in 1:pred_len]
130-
tspan_new = (km.tspan[2] + km.dt, predict_tsteps[end])
131-
u0 = km.model_data[:, end]
132-
model_pred_data = model(u0, tspan_new, predict_tsteps)[:, 2:end]
80+
mods = state_modifiers isa Tuple || state_modifiers isa AbstractVector ?
81+
Tuple(state_modifiers) : (state_modifiers,)
82+
stream_train = kb_stream_train(km, km.datasize)
83+
d_kb = size(stream_train, 1)
13384

134-
return obtain_esn_prediction(hesn, prediction, last_state, output_layer,
135-
model_pred_data;
136-
kwargs...)
137-
end
85+
ro = Readout((res_dims + d_kb) => out_dims, readout_activation;
86+
include_collect=static(include_collect))
13887

139-
function train(hesn::HybridESN, target_data::AbstractArray,
140-
training_method=StandardRidge(); kwargs...)
141-
states = vcat(hesn.states, hesn.model.model_data[:, 2:end])
142-
states_new = hesn.states_type(hesn.nla_type, states, hesn.train_data[:, 1:end])
88+
return ReservoirChain((StatefulLayer(cell), mods..., AttachStream(stream_train), ro)...)
89+
end
14390

144-
return train(training_method, states_new, target_data; kwargs...)
91+
function with_kb_stream(rc::ReservoirChain, new_stream::AbstractMatrix)
92+
layers = rc.layers
93+
names = propertynames(layers)
94+
vals = collect(Tuple(layers))
95+
found = false
96+
for (k, v) in enumerate(vals)
97+
if v isa AttachStream
98+
vals[k] = AttachStream(new_stream)
99+
found = true
100+
break
101+
end
102+
end
103+
@assert found "No AttachStream layer found in chain."
104+
new_nt = NamedTuple{names}(Tuple(vals))
105+
return ReservoirChain(new_nt, rc.name)
145106
end

0 commit comments

Comments
 (0)