Commit 3617293

refac: format, tests for hesn and scratch docs

1 parent 8652087 · commit 3617293 · 28 files changed: +868 −526 lines

docs/make.jl

Lines changed: 11 additions & 11 deletions
```diff
@@ -10,23 +10,23 @@ mathengine = Documenter.MathJax()

 bib = CitationBibliography(
     joinpath(@__DIR__, "src", "refs.bib");
-    style=:authoryear
+    style = :authoryear
 )

 links = InterLinks(
     "Lux" => "https://lux.csail.mit.edu/stable/"
 )

-makedocs(; modules=[ReservoirComputing],
-    sitename="ReservoirComputing.jl",
-    clean=true, doctest=false, linkcheck=true,
-    plugins=[links, bib],
-    format=Documenter.HTML(;
+makedocs(; modules = [ReservoirComputing],
+    sitename = "ReservoirComputing.jl",
+    clean = true, doctest = false, linkcheck = true,
+    plugins = [links, bib],
+    format = Documenter.HTML(;
         mathengine,
-        assets=["assets/favicon.ico"],
-        canonical="https://docs.sciml.ai/ReservoirComputing/stable/"),
-    pages=pages
+        assets = ["assets/favicon.ico"],
+        canonical = "https://docs.sciml.ai/ReservoirComputing/stable/"),
+    pages = pages
 )

-deploydocs(; repo="github.com/SciML/ReservoirComputing.jl.git",
-    push_preview=true)
+deploydocs(; repo = "github.com/SciML/ReservoirComputing.jl.git",
+    push_preview = true)
```

docs/pages.jl

Lines changed: 9 additions & 9 deletions
```diff
@@ -2,17 +2,17 @@ pages = [
     "ReservoirComputing.jl" => "index.md",
     "Tutorials" => Any[
         "Building a model from scratch" => "tutorials/scratch.md",
-        "Chaos forecasting with an ESN"=>"tutorials/lorenz_basic.md",
+        "Chaos forecasting with an ESN" => "tutorials/lorenz_basic.md",
         #"Using Different Training Methods" => "esn_tutorials/different_training.md",
-        "Deep Echo State Networks"=>"tutorials/deep_esn.md",
-        "Hybrid Echo State Networks"=>"tutorials/hybrid.md",
+        "Deep Echo State Networks" => "tutorials/deep_esn.md",
+        "Hybrid Echo State Networks" => "tutorials/hybrid.md",
         "Reservoir Computing with Cellular Automata" => "tutorials/reca.md"],
     "API Documentation" => Any[
-        "Layers"=>"api/layers.md",
-        "Models"=>"api/models.md",
-        "States"=>"api/states.md",
-        "Train"=>"api/train.md",
-        "Predict"=>"api/predict.md",
-        "Initializers"=>"api/inits.md"],
+        "Layers" => "api/layers.md",
+        "Models" => "api/models.md",
+        "States" => "api/states.md",
+        "Train" => "api/train.md",
+        "Predict" => "api/predict.md",
+        "Initializers" => "api/inits.md"],
     "References" => "references.md"
 ]
```

docs/src/tutorials/scratch.md

Lines changed: 169 additions & 0 deletions
New file: @@ -0,0 +1,169 @@

# Building a model from scratch

ReservoirComputing.jl provides utilities to build reservoir computing
models from scratch. In this tutorial we are going to build an echo
state network ([`ESN`](@ref)) and show that this custom implementation
is equivalent to the provided model (minus some convenience utilities).

## Using provided layers: ReservoirChain, ESNCell, and LinearReadout

The library provides a [`ReservoirChain`](@ref), which is virtually
equivalent to Lux's [`Chain`](@extref). Passing layers, or functions,
to the chain concatenates them and lets the input data flow through
the model.

To build an ESN we also need an [`ESNCell`](@ref) to provide the ESN
forward pass. However, the cell is stateless, so to keep a memory of
the input we need to wrap it in a [`StatefulLayer`](@ref), which saves
the internal state in the model states `st` and feeds it to the cell at
the next step.

Finally, we need the trainable readout for reservoir computing. The
library provides [`LinearReadout`](@ref), a dense layer whose weights
are trained using linear regression.

Putting it all together, we get the following:

```@example scratch
using ReservoirComputing

esn_scratch = ReservoirChain(
    StatefulLayer(
        ESNCell(3 => 50)
    ),
    LinearReadout(50 => 1)
)
```

Now, this implementation is, element naming aside, completely
equivalent to the following:

```@example scratch
esn = ESN(3, 50, 1)
```

and we can check this by initializing the two models and comparing, for
instance, the weights of the input layer:

```@example scratch
using Random
Random.seed!(43)

rng = MersenneTwister(17)
ps_s, st_s = setup(rng, esn_scratch)

rng = MersenneTwister(17)
ps, st = setup(rng, esn)

ps_s.layer_1.input_matrix == ps.cell.input_matrix
```

Both models can be trained using [`train!`](@ref), and predictions can
be obtained with [`predict`](@ref). The internal states collected for
linear regression are computed by traversing the
[`ReservoirChain`](@ref) and stopping right before the
[`LinearReadout`](@ref).
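
As a rough sketch of that round trip on toy data: note that the
argument order of `train!` and `predict` below is an assumption made
for illustration, mirroring the `(model, data, ps, st)` pattern that
`setup` and `collectstates` follow in this tutorial; see the Train and
Predict API pages for the authoritative signatures.

```julia
# Hypothetical toy data: 3 input features, 1 target, 300 time steps.
data = rand(Float32, 3, 300)
target = rand(Float32, 1, 300)

# Assumed calling convention: fit the readout, then run the chain.
ps, st = train!(esn, data, target, ps, st)
output, st = predict(esn, data, ps, st)
```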

## Manual state collection with Collect

For more complicated models, you will usually want to control when
state collection happens. In a [`ReservoirChain`](@ref), the collection
of states is controlled by the [`Collect`](@ref) layer. The role of
this layer is to tell the [`collectstates`](@ref) function where to
stop for state collection. All the readout layers have an
`include_collect=true` keyword, which inserts a [`Collect`](@ref) layer
before the readout. The model we wrote before can be written as

```@example scratch
esn_scratch = ReservoirChain(
    StatefulLayer(
        ESNCell(3 => 50)
    ),
    Collect(),
    LinearReadout(50 => 1; include_collect=false)
)
```

to make the collection explicit. This layer is useful when building
more complicated models, such as a [`DeepESN`](@ref). We can build a
deep model in multiple ways:

```@example scratch
deepesn_scratch = ReservoirChain(
    StatefulLayer(
        ESNCell(3 => 50)
    ),
    StatefulLayer(
        ESNCell(50 => 50)
    ),
    StatefulLayer(
        ESNCell(50 => 50)
    ),
    Collect(),
    LinearReadout(50 => 1; include_collect=false)
)
```

This first approach is the one provided by default in the library
through [`DeepESN`](@ref). However, you may want to collect the states
after each cell:

```@example scratch
deepesn_scratch = ReservoirChain(
    StatefulLayer(
        ESNCell(3 => 50)
    ),
    Collect(),
    StatefulLayer(
        ESNCell(50 => 50)
    ),
    Collect(),
    StatefulLayer(
        ESNCell(50 => 50)
    ),
    Collect(),
    LinearReadout(50 => 1; include_collect=false)
)
```

With this approach, the resulting state is the concatenation of the
states at each [`Collect`](@ref) point, so the states for this
architecture will be vectors of size 150:

```@example scratch
ps, st = setup(rng, deepesn_scratch)
states, st = collectstates(deepesn_scratch, rand(3, 300), ps, st)
size(states[:, 1])
```

This allows for even more complex constructions, where the state
collection follows specific patterns:

```@example scratch
deepesn_scratch = ReservoirChain(
    StatefulLayer(
        ESNCell(3 => 50)
    ),
    StatefulLayer(
        ESNCell(50 => 50)
    ),
    Collect(),
    StatefulLayer(
        ESNCell(50 => 50)
    ),
    Collect(),
    LinearReadout(50 => 1; include_collect=false)
)
```

Here, for instance, we have a [`Collect`](@ref) after the first two
cells and then one at the very end. You can see how the size of the
states is now 100:

```@example scratch
ps, st = setup(rng, deepesn_scratch)
states, st = collectstates(deepesn_scratch, rand(3, 300), ps, st)
size(states[:, 1])
```

Similar approaches can be leveraged, for instance, when the data show
multiscale dynamics that call for dedicated modeling choices.
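
One minimal sketch of such a pattern pairs a small and a large
reservoir and collects after each, so the readout is trained on
20 + 100 = 120 concatenated state entries per step. The sizes here are
illustrative, not tuned, and the readout is declared against the last
cell's output as in the examples above:

```julia
multiscale_scratch = ReservoirChain(
    StatefulLayer(ESNCell(3 => 20)),    # small reservoir
    Collect(),                          # contributes 20 state entries
    StatefulLayer(ESNCell(20 => 100)),  # larger reservoir
    Collect(),                          # contributes 100 state entries
    LinearReadout(100 => 1; include_collect=false)
)
```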

ext/RCCellularAutomataExt.jl

Lines changed: 14 additions & 14 deletions
```diff
@@ -1,23 +1,24 @@
 module RCCellularAutomataExt
 using ReservoirComputing: RECA, RandomMapping, RandomMaps, AbstractInputEncoding,
-    IntegerType, LinearReadout, ReservoirChain, StatefulLayer
+                          IntegerType, LinearReadout, ReservoirChain, StatefulLayer
 import ReservoirComputing: RECACell, RECA
 using CellularAutomata
 using Random: randperm

-function RandomMapping(; permutations=8, expansion_size=40)
+function RandomMapping(; permutations = 8, expansion_size = 40)
     RandomMapping(permutations, expansion_size)
 end

-function RandomMapping(permutations; expansion_size=40)
+function RandomMapping(permutations; expansion_size = 40)
     RandomMapping(permutations, expansion_size)
 end

 function create_encoding(rm::RandomMapping, in_dims::IntegerType, generations::IntegerType)
     maps = init_maps(in_dims, rm.permutations, rm.expansion_size)
     states_size = generations * rm.expansion_size * rm.permutations
     ca_size = rm.expansion_size * rm.permutations
-    return RandomMaps(rm.permutations, rm.expansion_size, generations, maps, states_size, ca_size)
+    return RandomMaps(
+        rm.permutations, rm.expansion_size, generations, maps, states_size, ca_size)
 end

 function encoding(rm::RandomMaps, input_vector, tot_encoded_vector)
```
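
For orientation, the sizes computed in `create_encoding` follow
directly from the mapping parameters; with the defaults above
(`permutations = 8`, `expansion_size = 40`) and the default
`generations = 8` from the `RECA` constructor further down, the
arithmetic works out as follows:

```julia
permutations, expansion_size, generations = 8, 40, 8

ca_size = expansion_size * permutations                    # 320 cells per CA row
states_size = generations * expansion_size * permutations  # 2560 state entries per step
```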
```diff
@@ -26,11 +27,11 @@ function encoding(rm::RandomMaps, input_vector, tot_encoded_vector)
     new_tot_enc_vec = copy(tot_encoded_vector)

     for i in 1:(rm.permutations)
-        new_tot_enc_vec[((i-1)*rm.expansion_size+1):(i*rm.expansion_size)] = single_encoding(
+        new_tot_enc_vec[((i - 1) * rm.expansion_size + 1):(i * rm.expansion_size)] = single_encoding(
             input_vector,
-            new_tot_enc_vec[((i-1)*rm.expansion_size+1):(i*rm.expansion_size)],
+            new_tot_enc_vec[((i - 1) * rm.expansion_size + 1):(i * rm.expansion_size)],
             rm.maps[i,
-                :])
+                    :])
     end

     return new_tot_enc_vec

@@ -84,13 +85,12 @@ function (reca::RECACell)(inp::AbstractVector, ps, st::NamedTuple)
 end

 function RECA(in_dims::IntegerType,
-    out_dims::IntegerType,
-    automaton;
-    input_encoding::AbstractInputEncoding=RandomMapping(),
-    generations::Integer=8,
-    state_modifiers=(),
-    readout_activation=identity)
-
+        out_dims::IntegerType,
+        automaton;
+        input_encoding::AbstractInputEncoding = RandomMapping(),
+        generations::Integer = 8,
+        state_modifiers = (),
+        readout_activation = identity)
     rm = create_encoding(input_encoding, in_dims, generations)
     cell = RECACell(automaton, rm)
```

ext/RCLIBSVMExt.jl

Lines changed: 2 additions & 2 deletions
```diff
@@ -2,11 +2,11 @@ module RCLIBSVMExt

 using LIBSVM
 using ReservoirComputing:
-    SVMReadout, addreadout!, ReservoirChain
+      SVMReadout, addreadout!, ReservoirChain
 import ReservoirComputing: train

 function train(svr::LIBSVM.AbstractSVR,
-    states::AbstractArray, target::AbstractArray)
+        states::AbstractArray, target::AbstractArray)
     @assert size(states, 2) == size(target, 2) "states and target must share columns."
     perm_states = permutedims(states)
     size_target = size(target, 1)
```
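
This method lets an SVM regressor from LIBSVM.jl stand in for the
linear solver. A hedged sketch of a direct call (the shapes are
illustrative and the returned readout layout is assumed):

```julia
using LIBSVM, ReservoirComputing

# Illustrative shapes: 50 collected states, 1 target, 300 steps;
# columns are time steps, matching the assertion in the method above.
states = rand(50, 300)
target = rand(1, 300)

# EpsilonSVR <: LIBSVM.AbstractSVR, so this dispatches to the extension.
readout = train(LIBSVM.EpsilonSVR(), states, target)
```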

ext/RCMLJLinearModelsExt.jl

Lines changed: 2 additions & 2 deletions
```diff
@@ -3,8 +3,8 @@ using ReservoirComputing
 using MLJLinearModels

 function ReservoirComputing.train(regressor::MLJLinearModels.GeneralizedLinearRegression,
-    states::AbstractMatrix{<:Real}, target::AbstractMatrix{<:Real};
-    kwargs...)
+        states::AbstractMatrix{<:Real}, target::AbstractMatrix{<:Real};
+        kwargs...)
     @assert size(states, 2) == size(target, 2) "states and target must share the same number of columns."

     if regressor.fit_intercept
```
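
Any `GeneralizedLinearRegression` from MLJLinearModels.jl matches this
signature; constructors such as `RidgeRegression` build one. A hedged
sketch (shapes illustrative, return value assumed):

```julia
using MLJLinearModels, ReservoirComputing

states = rand(50, 300)  # collected reservoir states, one column per step
target = rand(1, 300)   # matching targets

# RidgeRegression() returns a GeneralizedLinearRegression, so this
# dispatches to the extension method above.
readout = train(RidgeRegression(), states, target)
```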

src/ReservoirComputing.jl

Lines changed: 9 additions & 9 deletions
```diff
@@ -5,12 +5,12 @@ using Compat: @compat
 using ConcreteStructs: @concrete
 using LinearAlgebra: eigvals, mul!, I, qr, Diagonal, diag
 using LuxCore: AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer,
-    setup, apply, replicate
+               setup, apply, replicate
 import LuxCore: initialparameters, initialstates, statelength, outputsize
 using NNlib: fast_act, sigmoid
 using Random: Random, AbstractRNG, randperm
 using Static: StaticBool, StaticInt, StaticSymbol,
-    True, False, static, known, dynamic, StaticInteger
+              True, False, static, known, dynamic, StaticInteger
 using Reexport: Reexport, @reexport
 using WeightInitializers: DeviceAgnostic, PartialFunction, Utils
 @reexport using WeightInitializers

@@ -40,19 +40,19 @@ include("models/esn_hybrid.jl")
 include("extensions/reca.jl")

 export ESNCell, StatefulLayer, LinearReadout, ReservoirChain, Collect, collectstates,
-    train!,
-    predict, resetcarry!
+       train!,
+       predict, resetcarry!
 export SVMReadout
 export Pad, Extend, NLAT1, NLAT2, NLAT3, PartialSquare, ExtendedSquare
 export StandardRidge
 export chebyshev_mapping, informed_init, logistic_mapping, minimal_init,
-    modified_lm, scaled_rand, weighted_init, weighted_minimal
+       modified_lm, scaled_rand, weighted_init, weighted_minimal
 export block_diagonal, chaotic_init, cycle_jumps, delay_line, delay_line_backward,
-    double_cycle, forward_connection, low_connectivity, pseudo_svd, rand_sparse,
-    selfloop_cycle, selfloop_delayline_backward, selfloop_feedback_cycle,
-    selfloop_forward_connection, simple_cycle, true_double_cycle
+       double_cycle, forward_connection, low_connectivity, pseudo_svd, rand_sparse,
+       selfloop_cycle, selfloop_delayline_backward, selfloop_feedback_cycle,
+       selfloop_forward_connection, simple_cycle, true_double_cycle
 export add_jumps!, backward_connection!, delay_line!, reverse_simple_cycle!,
-    scale_radius!, self_loop!, simple_cycle!
+       scale_radius!, self_loop!, simple_cycle!
 export train
 export ESN, HybridESN, DeepESN
 #reca
```
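
The exported initializers (e.g. `scaled_rand`, `rand_sparse`) follow
the WeightInitializers.jl calling convention reexported above; a quick
hedged sketch of direct use, leaving keywords at their defaults since
they vary per initializer:

```julia
using ReservoirComputing, Random

rng = MersenneTwister(17)

# WeightInitializers-style calls: (rng, dims...) -> matrix.
W_in = scaled_rand(rng, 50, 3)    # dense scaled input weights, 50×3
W_res = rand_sparse(rng, 50, 50)  # sparse recurrent weights, 50×50
```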
