Commit 3f09d37

refactor: align state modifications to new apis
1 parent: 2aa3c4f

File tree: 10 files changed (+182, −467 lines)

README.md

Lines changed: 70 additions & 37 deletions
````diff
@@ -27,22 +27,18 @@ Use the
 [in-development documentation](https://docs.sciml.ai/ReservoirComputing/dev/)
 to take a look at not yet released features.
 
-## Citing
+## Features
 
-If you use this library in your work, please cite:
+ReservoirComputing.jl provides layers, models, and functions to help build and train
+reservoir computing models. More specifically, the software offers:
 
-```bibtex
-@article{martinuzzi2022reservoircomputing,
-    author = {Francesco Martinuzzi and Chris Rackauckas and Anas Abdelrehim and Miguel D. Mahecha and Karin Mora},
-    title = {ReservoirComputing.jl: An Efficient and Modular Library for Reservoir Computing Models},
-    journal = {Journal of Machine Learning Research},
-    year = {2022},
-    volume = {23},
-    number = {288},
-    pages = {1--8},
-    url = {http://jmlr.org/papers/v23/22-0611.html}
-}
-```
+- Base layers for reservoir computing model construction such as `ReservoirChain`,
+  `Readout`, `Collect`, and `ESNCell`
+- Fully built models such as `ESN` and `DeepESN`
+- 15+ reservoir initializers and 5+ input layer initializers
+- 5+ reservoir states modification algorithms
+- Sparse matrix computation through
+  [SparseArrays.jl](https://docs.julialang.org/en/v1/stdlib/SparseArrays/)
 
 ## Installation
 
@@ -63,67 +59,87 @@ Pkg.add("ReservoirComputing")
 
 To illustrate the workflow of this library we will showcase
 how it is possible to train an ESN to learn the dynamics of the
-Lorenz system. As a first step we gather the data.
-For the `Generative` prediction we need the target data
-to be one step ahead of the training data:
+Lorenz system.
+
+### 1. Generate data
+
+As a general first step we fix the random seed for reproducibility:
 
 ```julia
-using ReservoirComputing, OrdinaryDiffEq, Random
+using Random
 Random.seed!(42)
 rng = MersenneTwister(17)
+```
 
-#lorenz system parameters
-u0 = [1.0, 0.0, 0.0]
-tspan = (0.0, 200.0)
-p = [10.0, 28.0, 8 / 3]
+For an autoregressive prediction we need the target data
+to be one step ahead of the training data:
+
+```julia
+using OrdinaryDiffEq
 
 #define lorenz system
 function lorenz(du, u, p, t)
     du[1] = p[1] * (u[2] - u[1])
     du[2] = u[1] * (p[2] - u[3]) - u[2]
     du[3] = u[1] * u[2] - p[3] * u[3]
 end
+
 #solve and take data
-prob = ODEProblem(lorenz, u0, tspan, p)
+prob = ODEProblem(lorenz, [1.0f0, 0.0f0, 0.0f0], (0.0, 200.0), [10.0f0, 28.0f0, 8/3])
 data = Array(solve(prob, ABM54(); dt=0.02))
-
 shift = 300
 train_len = 5000
 predict_len = 1250
 
 #one step ahead for generative prediction
 input_data = data[:, shift:(shift + train_len - 1)]
 target_data = data[:, (shift + 1):(shift + train_len)]
-
 test = data[:, (shift + train_len):(shift + train_len + predict_len - 1)]
 ```
 
-Now that we have the data we can initialize the ESN with the chosen parameters.
-Given that this is a quick example we are going to change the least amount of
-possible parameters:
+### 2. Build Echo State Network
+
+We can either use the provided `ESN` or build one from scratch.
+We showcase the second option:
 
 ```julia
 input_size = 3
 res_size = 300
 esn = ReservoirChain(
-    StatefulLayer(ESNCell(input_size => res_size; init_reservoir=rand_sparse(; radius=1.2, sparsity=6/300))),
+    StatefulLayer(
+        ESNCell(
+            input_size => res_size;
+            init_reservoir=rand_sparse(; radius=1.2, sparsity=6/300)
+        )
+    ),
     NLAT2(),
-    Readout(res_size => input_size)
-) #or ESN(input_size, res_size, input_size; init_reservoir=rand_sparse(; radius=1.2, sparsity=6/300))
+    Readout(res_size => input_size) # autoregressive so out_dims == in_dims
+)
+# alternative:
+# esn = ESN(input_size, res_size, input_size;
+#     init_reservoir=rand_sparse(; radius=1.2, sparsity=6/300)
+# )
 ```
 
-The echo state network can now be trained and tested.
-If not specified, the training will always be ordinary least squares regression:
+### 3. Train the Echo State Network
+
+ReservoirComputing.jl builds on Lux(Core), so in order to train the model
+we first need to instantiate the parameters and the states:
 
 ```julia
 ps, st = setup(rng, esn)
 ps, st = train!(esn, input_data, target_data, ps, st)
-output, _ = predict(esn, 1250, ps, st; initialdata=test[:, 1])
 ```
 
-The data is returned as a matrix, `output` in the code above,
-that contains the predicted trajectories.
-The results can now be easily plotted:
+### 4. Predict and visualize
+
+We can now use the trained ESN to forecast the Lorenz system dynamics:
+
+```julia
+output, st = predict(esn, 1250, ps, st; initialdata=test[:, 1])
+```
+
+We can now visualize the results:
 
 ```julia
 using Plots
@@ -146,6 +162,23 @@ plot!(transpose(test)[:, 1], transpose(test)[:, 2], transpose(test)[:, 3]; label
 
 ![lorenz_attractor](https://user-images.githubusercontent.com/10376688/81470281-5a34b580-91ea-11ea-9eea-d2b266da19f4.png)
 
+## Citing
+
+If you use this library in your work, please cite:
+
+```bibtex
+@article{martinuzzi2022reservoircomputing,
+    author = {Francesco Martinuzzi and Chris Rackauckas and Anas Abdelrehim and Miguel D. Mahecha and Karin Mora},
+    title = {ReservoirComputing.jl: An Efficient and Modular Library for Reservoir Computing Models},
+    journal = {Journal of Machine Learning Research},
+    year = {2022},
+    volume = {23},
+    number = {288},
+    pages = {1--8},
+    url = {http://jmlr.org/papers/v23/22-0611.html}
+}
+```
+
 ## Acknowledgements
 
 This project was possible thanks to initial funding through
````

docs/make.jl

Lines changed: 3 additions & 3 deletions
````diff
@@ -1,7 +1,7 @@
-using Documenter, DocumenterCitations, ReservoirComputing
+using Documenter, DocumenterCitations, DocumenterInterLinks, ReservoirComputing
 
-cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml"; force=true)
-cp("./docs/Project.toml", "./docs/src/assets/Project.toml"; force=true)
+#cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml"; force=true)
+#cp("./docs/Project.toml", "./docs/src/assets/Project.toml"; force=true)
 
 ENV["PLOTS_TEST"] = "true"
 ENV["GKSwstype"] = "100"
````

docs/pages.jl

Lines changed: 1 addition & 1 deletion
````diff
@@ -17,7 +17,7 @@ pages = [
     "Layers"=>"api/layers.md",
     "Models"=>"api/models.md",
     "States"=>"api/states.md",
-    "Train"=>"api/training.md",
+    "Train"=>"api/train.md",
     "Predict"=>"api/predict.md",
     "Initializers"=>"api/inits.md",
     "ReCA"=>"api/reca.md"] #"References" => "references.md"
````

docs/src/api/states.md

Lines changed: 3 additions & 19 deletions
````diff
@@ -1,34 +1,18 @@
 # States Modifications
 
-## Padding and Estension
-
-```@docs
-StandardStates
-ExtendedStates
-PaddedStates
-PaddedExtendedStates
-```
-
-## Non Linear Transformations
-
 ```@docs
-NLADefault
+Pad
+Extend
 NLAT1
 NLAT2
 NLAT3
 PartialSquare
 ExtendedSquare
 ```
 
-## Internals
-
-```@docs
-ReservoirComputing.create_states
-```
-
 ## References
 
 ```@bibliography
 Pages = ["states.md"]
 Canonical = false
-```
+```
````
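To make the new composable API concrete, here is a rough sketch of how the renamed state modifications might slot into a chain. `Pad`'s constructor argument and the resulting state size are assumptions inferred from the old `PaddedStates` behavior, not confirmed by this diff:

```julia
using ReservoirComputing, Random

# Sketch only: Pad(1.0) is assumed to append a constant row to the state,
# mirroring the old PaddedStates(padding=1.0); verify against the docs page above.
esn = ReservoirChain(
    StatefulLayer(ESNCell(3 => 300)),
    NLAT2(),           # nonlinear state transformation, as documented above
    Pad(1.0),          # assumed signature: pads the state from 300 to 301 rows
    Readout(301 => 3)  # readout must match the padded state size
)
ps, st = setup(MersenneTwister(17), esn)
```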

ext/RCMLJLinearModelsExt.jl

Lines changed: 12 additions & 11 deletions
````diff
@@ -3,22 +3,23 @@ using ReservoirComputing
 using MLJLinearModels
 
 function ReservoirComputing.train(regressor::MLJLinearModels.GeneralizedLinearRegression,
-        states::AbstractArray{T}, target::AbstractArray{T};
-        kwargs...) where {T <: Number}
-    out_size = size(target, 1)
-    output_layer = similar(target, size(target, 1), size(states, 1))
+        states::AbstractMatrix{<:Real}, target::AbstractMatrix{<:Real};
+        kwargs...)
+    @assert size(states, 2) == size(target, 2) "states and target must share the same number of columns."
 
     if regressor.fit_intercept
-        throw(ArgumentError("fit_intercept=true is not yet supported.
-            Please add fit_intercept=false to the MLJ regressor"))
+        throw(ArgumentError("fit_intercept=true not supported here. \
+            Either set fit_intercept=false on the MLJ regressor, or extend addreadout! to write bias."))
     end
-
-    for i in axes(target, 1)
-        output_layer[i, :] = MLJLinearModels.fit(regressor, states',
-            target[i, :]; kwargs...)
+    permuted_states = permutedims(states)
+    output_matrix = similar(target, size(target, 1), size(states, 1))
+    for idx in axes(target, 1)
+        yi = vec(target[idx, :])
+        coefs = MLJLinearModels.fit(regressor, permuted_states, yi; kwargs...)
+        output_matrix[idx, :] = coefs
     end
 
-    return OutputLayer(regressor, output_layer, out_size, target[:, end])
+    return output_matrix
 end
 
 end #module
````
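With this refactor `train` returns the bare readout matrix instead of an `OutputLayer`. A minimal usage sketch, assuming only what the diff shows; the ridge regressor choice and the random data are illustrative:

```julia
using MLJLinearModels
using ReservoirComputing

# res_size × n_steps collected states, out_dims × n_steps targets
states = rand(Float32, 300, 500)
target = rand(Float32, 3, 500)

# fit_intercept must be false, as enforced by the ArgumentError above
regressor = RidgeRegression(1e-3; fit_intercept=false)
Wout = ReservoirComputing.train(regressor, states, target)
@assert size(Wout) == (3, 300)  # one coefficient row per output dimension
```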

src/ReservoirComputing.jl

Lines changed: 1 addition & 2 deletions
````diff
@@ -46,8 +46,7 @@ include("extensions/reca.jl")
 
 export ESNCell, StatefulLayer, Readout, ReservoirChain, Collect, collectstates, train!, predict
 
-export NLADefault, NLAT1, NLAT2, NLAT3, PartialSquare, ExtendedSquare
-export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates
+export Pad, Extend, NLAT1, NLAT2, NLAT3, PartialSquare, ExtendedSquare
 export StandardRidge
 export chebyshev_mapping, informed_init, logistic_mapping, minimal_init,
     modified_lm, scaled_rand, weighted_init, weighted_minimal
````
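The export diff doubles as a migration table. A hedged reading, with old-to-new equivalences inferred from names alone rather than verified semantics:

```julia
# Old export          New export (assumed equivalent)
# PaddedStates     -> Pad
# ExtendedStates   -> Extend
# StandardStates   -> dropped (plain states are presumably the default)
# NLADefault       -> dropped from the exports
# NLAT1/NLAT2/NLAT3, PartialSquare, ExtendedSquare remain unchanged
using ReservoirComputing
@assert all(n -> isdefined(ReservoirComputing, n),
    (:Pad, :Extend, :NLAT1, :NLAT2, :NLAT3, :PartialSquare, :ExtendedSquare))
```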

src/extensions/reca.jl

Lines changed: 4 additions & 4 deletions
````diff
@@ -14,12 +14,12 @@ The detail of this implementation can be found in [1].
 [1] Nichele, Stefano, and Andreas Molund. “Deep reservoir computing using cellular
 automata.” arXiv preprint arXiv:1703.02806 (2017).
 """
-struct RandomMapping{I, T} <: AbstractInputEncoding
+struct RandomMapping{I,T} <: AbstractInputEncoding
     permutations::I
     expansion_size::T
 end
 
-struct RandomMaps{T, E, G, M, S} <: AbstractEncodingData
+struct RandomMaps{T,E,G,M,S} <: AbstractEncodingData
     permutations::T
     expansion_size::E
     generations::G
@@ -44,12 +44,12 @@ arXiv preprint arXiv:1410.0162 (2014).
 [2] Nichele, Stefano, and Andreas Molund. “_Deep reservoir computing using cellular
 automata._” arXiv preprint arXiv:1703.02806 (2017).
 """
-struct RECA{S, R, E, T, Q} <: AbstractReca
+struct RECA{S,R,E,N,T,Q} <: AbstractReca
     #res_size::I
     train_data::S
     automata::R
     input_encoding::E
-    nla_type::ReservoirComputing.NonLinearAlgorithm
+    nla_type::N
     states::T
     states_type::Q
 end
````
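The `nla_type::N` change swaps an abstractly typed field for a type parameter. A minimal, self-contained sketch of why that matters; the types here are generic Julia for illustration, not this package's actual types:

```julia
# Hypothetical types for illustration only.
abstract type AbstractNLA end
struct SomeNLA <: AbstractNLA end

# Abstract field: the compiler cannot infer the concrete type on access,
# so every use of r.nla_type goes through dynamic dispatch.
struct AbstractlyTyped
    nla_type::AbstractNLA
end

# Parametric field: each instance carries a concrete N, so accesses
# are fully inferable and methods specialize per nla_type.
struct ParametricallyTyped{N}
    nla_type::N
end
```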

src/models/hybridesn.jl

Lines changed: 4 additions & 9 deletions
````diff
@@ -94,19 +94,14 @@ traditional Echo State Networks with a predefined knowledge model [^Pathak2018].
 function HybridESN(model::KnowledgeModel, train_data::AbstractArray,
         in_size::Int, res_size::Int; input_layer=scaled_rand, reservoir=rand_sparse,
         bias=zeros32, reservoir_driver=RNN(),
-        nla_type::NonLinearAlgorithm=NLADefault(),
-        states_type::AbstractStates=StandardStates(), washout::Int=0,
+        nla_type=NLADefault(),
+        states_type=StandardStates(), washout::Int=0,
         rng::AbstractRNG=Utils.default_rng(), T=Float32,
         matrix_type=typeof(train_data))
     train_data = vcat(train_data, model.model_data[:, 1:(end-1)])
 
-    if states_type isa AbstractPaddedStates
-        in_size = size(train_data, 1) + 1
-        train_data = vcat(adapt(matrix_type, ones(1, size(train_data, 2))),
-            train_data)
-    else
-        in_size = size(train_data, 1)
-    end
+    in_size = size(train_data, 1)
+
 
     reservoir_matrix = reservoir(rng, T, res_size, res_size)
     #different from ESN, why?
````
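The constructor no longer special-cases padded states; `in_size` now always tracks the stacked training data. If downstream code relied on the removed branch, here is a standalone sketch of what it did; the helper name is hypothetical, and `adapt` comes from Adapt.jl exactly as in the deleted code:

```julia
using Adapt

# Hypothetical helper reproducing the deleted branch: prepend a row of
# ones so the readout can absorb a constant bias term.
function pad_train_data(train_data::AbstractMatrix,
        matrix_type::Type=typeof(train_data))
    ones_row = adapt(matrix_type, ones(1, size(train_data, 2)))
    return vcat(ones_row, train_data), size(train_data, 1) + 1
end

padded, in_size = pad_train_data(rand(Float32, 3, 100))  # in_size == 4
```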
