Skip to content

Commit 63384da

Browse files
fix: small fixes and docs
1 parent 0333d44 commit 63384da

File tree

5 files changed

+52
-19
lines changed

5 files changed

+52
-19
lines changed

src/inits/esn_inits.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -791,7 +791,7 @@ function pseudo_svd(rng::AbstractRNG, ::Type{T}, dims::Integer...;
791791
while tmp <= sparsity
792792
i = rand_range(rng, res_dim)
793793
j = rand_range(rng, res_dim)
794-
θ = DeviceAgnostic.rand(rng, T) * T(2) - T(1)
794+
θ = DeviceAgnostic.rand(rng, T) * T(2) .- T(1)
795795
reservoir_matrix = reservoir_matrix * create_qmatrix(rng, T, res_dim, i, j, θ)
796796
tmp = get_sparsity(reservoir_matrix, res_dim)
797797
end

src/layers/esn_cell.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
abstract type AbstractEchoStateNetworkCell <: AbstractReservoirRecurrentCell end
2+
13
@doc raw"""
24
ESNCell(in_dims => out_dims, [activation];
35
use_bias=false, init_bias=rand32,
@@ -63,7 +65,7 @@ Created by `initialstates(rng, esn)`:
6365
6466
- `rng`: a replicated RNG used to sample initial hidden states when needed.
6567
"""
66-
@concrete struct ESNCell <: AbstractReservoirRecurrentCell
68+
@concrete struct ESNCell <: AbstractEchoStateNetworkCell
6769
activation::Any
6870
in_dims <: IntegerType
6971
out_dims <: IntegerType
@@ -76,15 +78,15 @@ Created by `initialstates(rng, esn)`:
7678
use_bias <: StaticBool
7779
end
7880

79-
function ESNCell(
80-
(in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType}, activation = tanh;
81-
use_bias::BoolType = False(), init_bias = zeros32, init_reservoir = rand_sparse,
82-
init_input = scaled_rand, init_state = randn32, leak_coefficient = 1.0)
81+
function ESNCell((in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType},
82+
activation = tanh; use_bias::BoolType = False(), init_bias = zeros32,
83+
init_reservoir = rand_sparse, init_input = scaled_rand,
84+
init_state = randn32, leak_coefficient = 1.0)
8385
return ESNCell(activation, in_dims, out_dims, init_bias, init_reservoir,
8486
init_input, init_state, leak_coefficient, use_bias)
8587
end
8688

87-
function initialparameters(rng::AbstractRNG, esn::ESNCell)
89+
function initialparameters(rng::AbstractRNG, esn::AbstractEchoStateNetworkCell)
8890
ps = (input_matrix = esn.init_input(rng, esn.out_dims, esn.in_dims),
8991
reservoir_matrix = esn.init_reservoir(rng, esn.out_dims, esn.out_dims))
9092
if has_bias(esn)
@@ -93,11 +95,11 @@ function initialparameters(rng::AbstractRNG, esn::ESNCell)
9395
return ps
9496
end
9597

96-
function initialstates(rng::AbstractRNG, esn::ESNCell)
98+
function initialstates(rng::AbstractRNG, esn::AbstractEchoStateNetworkCell)
9799
return (rng = sample_replicate(rng),)
98100
end
99101

100-
function (esn::ESNCell)(inp::AbstractArray, ps, st::NamedTuple)
102+
function (esn::AbstractEchoStateNetworkCell)(inp::AbstractArray, ps, st::NamedTuple)
101103
rng = replicate(st.rng)
102104
hidden_state = init_hidden_state(rng, esn, inp)
103105
return esn((inp, (hidden_state,)), ps, merge(st, (; rng)))

src/reservoircomputer.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,34 @@
1+
2+
@doc raw"""
3+
ReservoirComputer(reservoir, states_modifiers, readout)
4+
5+
Generic reservoir-computing container that wires together:
6+
1) a `reservoir` (any Lux-compatible layer producing features),
7+
2) zero or more `states_modifiers` applied sequentially to the reservoir features,
8+
3) a `readout` layer (typically [`LinearReadout`](@ref)).
9+
10+
The container exposes a standard `(x, ps, st) -> (y, st′)` interface and
11+
utility functions to initialize parameters/states, stream sequences to collect
12+
features, and install trained readout weights.
13+
14+
## Arguments
15+
16+
- `reservoir`: a layer that consumes inputs and produces feature vectors.
17+
- `states_modifiers`: a tuple (or vector converted to `Tuple`) of layers applied
18+
after the reservoir (may be empty).
19+
- `readout`: the final trainable layer mapping features to outputs.
20+
21+
## Inputs
22+
23+
- `x`: input to the reservoir (shape determined by the reservoir).
24+
- `ps`: reservoir computing parameters.
25+
- `st`: reservoir computing states.
26+
27+
## Returns
28+
29+
- `(y, st′)` where `y` is the readout output and `st′` contains the updated
30+
states of the reservoir, modifiers, and readout.
31+
"""
132
@concrete struct ReservoirComputer <: AbstractReservoirComputer{(:reservoir, :states_modifiers, :readout)}
233
reservoir::Any
334
states_modifiers::Any

test/esn/test_esn.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,18 @@ end
3434

3535
ps, st = setup(rng, esn)
3636

37-
@test haskey(ps, :cell)
38-
@test haskey(ps.cell, :input_matrix)
39-
@test haskey(ps.cell, :reservoir_matrix)
40-
@test !haskey(ps.cell, :bias)
41-
@test size(ps.cell.input_matrix) == (res_dims, in_dims)
42-
@test size(ps.cell.reservoir_matrix) == (res_dims, res_dims)
37+
@test haskey(ps, :reservoir)
38+
@test haskey(ps.reservoir, :input_matrix)
39+
@test haskey(ps.reservoir, :reservoir_matrix)
40+
@test !haskey(ps.reservoir, :bias)
41+
@test size(ps.reservoir.input_matrix) == (res_dims, in_dims)
42+
@test size(ps.reservoir.reservoir_matrix) == (res_dims, res_dims)
4343

4444
@test haskey(ps, :readout)
4545
@test haskey(ps.readout, :weight)
4646
@test size(ps.readout.weight) == (out_dims, res_dims)
4747

48-
@test haskey(st, :cell)
48+
@test haskey(st, :reservoir)
4949
@test haskey(st, :states_modifiers)
5050
@test haskey(st, :readout)
5151
@test st.states_modifiers isa Tuple
@@ -72,7 +72,7 @@ end
7272

7373
@test size(Y) == (D, 1)
7474
@test vec(Y) x
75-
@test haskey(st2, :cell) && haskey(st2, :states_modifiers) && haskey(st2, :readout)
75+
@test haskey(st2, :reservoir) && haskey(st2, :states_modifiers) && haskey(st2, :readout)
7676
end
7777

7878
@testset "ESN: forward (batch matrix) with identity pipeline -> Y == X" begin
@@ -155,7 +155,7 @@ end
155155
y, st2 = esn(x, ps, st)
156156

157157
@test y x
158-
@test haskey(st2, :cell)
158+
@test haskey(st2, :reservoir)
159159
@test haskey(st2, :states_modifiers)
160160
@test haskey(st2, :readout)
161161
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33

44
@testset "Common Utilities" begin
55
@safetestset "Quality Assurance" include("qa.jl")
6-
#@safetestset "States" include("test_states.jl")
6+
@safetestset "States" include("test_states.jl")
77
end
88

99
@testset "Layers" begin

0 commit comments

Comments
 (0)