-struct HybridESN{I,S,V,N,T,O,M,B,ST,W,IS}
-    res_size::I
-    train_data::S
-    model::V
-    nla_type::N
-    input_matrix::T
-    reservoir_driver::O
-    reservoir_matrix::M
-    bias_vector::B
-    states_type::ST
-    washout::W
-    states::IS
-end
+# ###########################
+#  Knowledge-model wrapper  #
+# ###########################

 struct KnowledgeModel{T,K,O,I,S,D}
     prior_model::T
@@ -22,124 +12,95 @@ struct KnowledgeModel{T,K,O,I,S,D}
 end

 """
-    KnowledgeModel(prior_model, u0, tspan, datasize)
-
-Constructs a `Hybrid` variation of Echo State Networks (ESNs) [^Pathak2018]
-integrating a knowledge-based model (`prior_model`) with ESNs.
-
-# Parameters
+    KnowledgeModel(prior_model, u0, tspan, datasize)

-  - `prior_model`: A knowledge-based model function for integration with ESNs.
-  - `u0`: Initial conditions for the model.
-  - `tspan`: Time span as a tuple, indicating the duration for model operation.
-  - `datasize`: The size of the data to be processed.
-
-[^Pathak2018]: Jaideep Pathak et al.
-    "Hybrid Forecasting of Chaotic Processes:
-    Using Machine Learning in Conjunction with a Knowledge-Based Model" (2018).
+Build a `KnowledgeModel` and precompute `model_data` on a time grid of length
+`datasize+1`. The extra step aligns with teacher-forced (xₜ → yₜ₊₁) usage.
 """
 function KnowledgeModel(prior_model, u0, tspan, datasize)
     trange = collect(range(tspan[1], tspan[2]; length=datasize))
+    @assert length(trange) ≥ 2 "datasize must be ≥ 2 to infer dt"
     dt = trange[2] - trange[1]
-    tsteps = push!(trange, dt + trange[end])
-    tspan_new = (tspan[1], dt + tspan[2])
-    model_data = prior_model(u0, tspan_new, tsteps)
-    return KnowledgeModel(prior_model, u0, tspan, dt, datasize, model_data)
+    tsteps = push!(trange, trange[end] + dt)
+    tspan2 = (tspan[1], tspan[2] + dt)
+    mdl = prior_model(u0, tspan2, tsteps)
+    return KnowledgeModel(prior_model, u0, tspan, dt, datasize, mdl)
 end
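
A usage sketch for this constructor. `lorenz_prior` below is a hypothetical prior model (not part of this file), written with OrdinaryDiffEq only to make the `prior_model` contract concrete: it must accept `(u0, tspan, tsteps)` and return a `states × timesteps` matrix.

using OrdinaryDiffEq

function lorenz_prior(u0, tspan, tsteps)
    # Classic Lorenz-63 system as the knowledge-based prior.
    function lorenz!(du, u, p, t)
        du[1] = 10.0 * (u[2] - u[1])
        du[2] = u[1] * (28.0 - u[3]) - u[2]
        du[3] = u[1] * u[2] - (8 / 3) * u[3]
    end
    prob = ODEProblem(lorenz!, u0, tspan)
    return Array(solve(prob, Tsit5(); saveat=tsteps))   # states × timesteps
end

km = KnowledgeModel(lorenz_prior, [1.0, 0.0, 0.0], (0.0, 100.0), 1000)
size(km.model_data)   # (3, 1001): one extra column past `datasize`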

-"""
-    HybridESN(model, train_data, in_size, res_size; kwargs...)
-
-Construct a Hybrid Echo State Network (ESN) model that integrates
-traditional Echo State Networks with a predefined knowledge model [^Pathak2018].
-
-# Parameters
-
-  - `model`: A `KnowledgeModel` instance representing the knowledge-based model
-    to be integrated with the ESN.
-  - `train_data`: The training dataset used for the ESN. This data can be
-    preprocessed or raw data depending on the nature of the problem and the
-    preprocessing steps considered.
-  - `in_size`: The size of the input layer, i.e., the number of input units
-    to the ESN.
-  - `res_size`: The size of the reservoir, i.e., the number of neurons in
-    the hidden layer of the ESN.
-
-# Optional Keyword Arguments
-
-  - `input_layer`: A function to initialize the input matrix.
-    Default is `scaled_rand`.
-  - `reservoir`: A function to initialize the reservoir matrix.
-    Default is `rand_sparse`.
-  - `bias`: A function to initialize the bias vector.
-    Default is `zeros32`.
-  - `reservoir_driver`: The driving system for the reservoir.
-    Default is an RNN model.
-  - `nla_type`: The type of non-linear activation used in the reservoir.
-    Default is `NLADefault()`.
-  - `states_type`: Defines the type of states used in the
-    ESN. Default is `StandardStates()`.
-  - `washout`: The number of initial timesteps to be
-    discarded in the ESN's training phase. Default is 0.
-  - `rng`: Random number generator used for initializing weights.
-    Default is `Utils.default_rng()`.
-  - `T`: The data type for the matrices (e.g., `Float32`).
-  - `matrix_type`: The type of matrix used for storing the training data.
-    Default is inferred from `train_data`.
+# Helper: forecast a KB stream for `steps` auto-regressive steps beyond tspan
+function kb_forecast(km::KnowledgeModel, steps::Integer)
+    @assert steps ≥ 1
+    t0 = km.tspan[2] + km.dt
+    tgrid = collect(t0:km.dt:(t0 + km.dt * (steps - 1)))
+    tspan = (t0, tgrid[end])
+    u0 = km.model_data[:, end]
+    mdl = km.prior_model(u0, tspan, [t0; tgrid[2:end]])
+    return mdl
+end
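
Continuing the hypothetical Lorenz setup, `kb_forecast` runs the prior for `steps` points spaced `km.dt` apart, starting one step past `km.tspan[2]` from the hand-off state `km.model_data[:, end]`:

pred_stream = kb_forecast(km, 50)   # placeholder horizon of 50 steps
size(pred_stream)                   # (3, 50) with the OrdinaryDiffEq-style prior above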

-[^Pathak2018]: Jaideep Pathak et al.
-    "Hybrid Forecasting of Chaotic Processes:
-    Using Machine Learning in Conjunction with a Knowledge-Based Model" (2018).
-"""
-function HybridESN(model::KnowledgeModel, train_data::AbstractArray,
-        in_size::Int, res_size::Int; input_layer=scaled_rand, reservoir=rand_sparse,
-        bias=zeros32, reservoir_driver=RNN(),
-        nla_type=NLADefault(),
-        states_type=StandardStates(), washout::Int=0,
-        rng::AbstractRNG=Utils.default_rng(), T=Float32,
-        matrix_type=typeof(train_data))
-    train_data = vcat(train_data, model.model_data[:, 1:(end - 1)])
+kb_stream_train(km::KnowledgeModel, T::Integer) = km.model_data[:, 1:T]

-    in_size = size(train_data, 1)

+# Concats a column from `stream` at each step: z_t = vcat(x_t, stream[:, i])
+@concrete struct AttachStream <: AbstractLuxLayer
+    stream <: AbstractMatrix
+end

-    reservoir_matrix = reservoir(rng, T, res_size, res_size)
-    # different from ESN, why?
-    input_matrix = input_layer(rng, T, res_size, in_size)
-    bias_vector = bias(rng, res_size)
-    inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size)
-    states = create_states(inner_res_driver, train_data, washout, reservoir_matrix,
-        input_matrix, bias_vector)
-    train_data = train_data[:, (washout + 1):end]
+initialparameters(::AbstractRNG, ::AttachStream) = NamedTuple()
+initialstates(::AbstractRNG, ::AttachStream) = (i=1,)

-    return HybridESN(res_size, train_data, model, nla_type, input_matrix,
-        inner_res_driver, reservoir_matrix, bias_vector, states_type, washout,
-        states)
+function (l::AttachStream)(x::AbstractVector, ps, st::NamedTuple)
+    @boundscheck (st.i ≤ size(l.stream, 2)) ||
+        throw(BoundsError(l.stream, st.i))
+    out = vcat(x, @view l.stream[:, st.i])
+    return out, (i=st.i + 1,)
 end
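
Since `AttachStream` is an ordinary Lux-style layer, it can be exercised on its own; a minimal sketch with a hypothetical 2-variable stream, using the empty parameters and `(i = 1,)` state defined just above:

stream = reshape(collect(1.0:6.0), 2, 3)    # 2 KB variables × 3 time steps
layer = AttachStream(stream)
ps, st = NamedTuple(), (i = 1,)             # matches initialparameters / initialstates
z, st = layer([0.1, 0.2], ps, st)           # z == [0.1, 0.2, 1.0, 2.0], st == (i = 2,)
z, st = layer([0.3, 0.4], ps, st)           # second call appends stream[:, 2] == [3.0, 4.0]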

-function (hesn::HybridESN)(prediction,
-        output_layer, last_state::AbstractArray=hesn.states[
-            :, [end]],
+"""
+    HybridESN(km::KnowledgeModel,
+        in_dims::Integer, res_dims::Integer, out_dims::Integer,
+        activation=tanh;
+        state_modifiers=(),
+        readout_activation=identity,
+        include_collect=true,
+        kwargs...)
+
+Build a hybrid ESN as a `ReservoirChain`:
+`StatefulLayer(ESNCell) → modifiers → AttachStream(train KB) → Readout`.
+"""
+function HybridESN(km::KnowledgeModel,
+        in_dims::Integer, res_dims::Integer, out_dims::Integer,
+        activation=tanh;
+        state_modifiers=(),
+        readout_activation=identity,
+        include_collect::Bool=true,
         kwargs...)
-    km = hesn.model
-    pred_len = prediction.prediction_len
+    cell = ESNCell(in_dims => res_dims, activation; kwargs...)

-    model = km.prior_model
-    predict_tsteps = [km.tspan[2] + km.dt]
-    [append!(predict_tsteps, predict_tsteps[end] + km.dt) for i in 1:pred_len]
-    tspan_new = (km.tspan[2] + km.dt, predict_tsteps[end])
-    u0 = km.model_data[:, end]
-    model_pred_data = model(u0, tspan_new, predict_tsteps)[:, 2:end]
+    mods = state_modifiers isa Tuple || state_modifiers isa AbstractVector ?
+           Tuple(state_modifiers) : (state_modifiers,)
+    stream_train = kb_stream_train(km, km.datasize)
+    d_kb = size(stream_train, 1)

-    return obtain_esn_prediction(hesn, prediction, last_state, output_layer,
-        model_pred_data;
-        kwargs...)
-end
+    ro = Readout((res_dims + d_kb) => out_dims, readout_activation;
+        include_collect=static(include_collect))

-function train(hesn::HybridESN, target_data::AbstractArray,
-        training_method=StandardRidge(); kwargs...)
-    states = vcat(hesn.states, hesn.model.model_data[:, 2:end])
-    states_new = hesn.states_type(hesn.nla_type, states, hesn.train_data[:, 1:end])
+    return ReservoirChain((StatefulLayer(cell), mods..., AttachStream(stream_train), ro)...)
+end
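
A hedged construction sketch, reusing `km` from the earlier Lorenz sketch; the 3-dimensional input/output and the 300-unit reservoir are placeholders, and any extra keyword arguments are forwarded to `ESNCell`:

hesn = HybridESN(km, 3, 300, 3, tanh;
    state_modifiers=(), readout_activation=identity)
# Resulting chain: StatefulLayer(ESNCell(3 => 300, tanh)) →
#     AttachStream(km.model_data[:, 1:km.datasize]) → Readout((300 + 3) => 3)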

-    return train(training_method, states_new, target_data; kwargs...)
+# Rebuild a chain with its AttachStream pointing at a new KB stream (e.g. a forecast).
+function with_kb_stream(rc::ReservoirChain, new_stream::AbstractMatrix)
+    layers = rc.layers
+    names = propertynames(layers)
+    vals = collect(Tuple(layers))
+    found = false
+    for (k, v) in enumerate(vals)
+        if v isa AttachStream
+            vals[k] = AttachStream(new_stream)
+            found = true
+            break
+        end
+    end
+    @assert found "No AttachStream layer found in chain."
+    new_nt = NamedTuple{names}(Tuple(vals))
+    return ReservoirChain(new_nt, rc.name)
 end
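
At prediction time the training-period KB stream baked into `AttachStream` can be swapped for a forecast from `kb_forecast`; a sketch reusing `hesn` and `km` from above, with a placeholder horizon of 200 steps:

stream_pred = kb_forecast(km, 200)             # KB trajectory beyond the training span
hesn_pred = with_kb_stream(hesn, stream_pred)  # same layers, new AttachStream source
# AttachStream carries no trainable parameters, so only the attached matrix changes.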