Commit 602ecf0

feat: finish ESN, add reset_carry
1 parent ed10b6d commit 602ecf0

File tree: 3 files changed (+86 additions, -127 deletions)

src/ReservoirComputing.jl

Lines changed: 13 additions & 11 deletions
@@ -6,20 +6,20 @@ using ConcreteStructs: @concrete
 #using Functors
 using LinearAlgebra: eigvals, mul!, I, qr, Diagonal, diag
 using LuxCore: AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer,
-    setup, apply, replicate
+               setup, apply, replicate
 import LuxCore: initialparameters, initialstates, statelength, outputsize
 using NNlib: fast_act, sigmoid
 using Random: Random, AbstractRNG, randperm
 using Static: StaticBool, StaticInt, StaticSymbol,
-    True, False, static, known, dynamic, StaticInteger
+              True, False, static, known, dynamic, StaticInteger
 using Reexport: Reexport, @reexport
 using WeightInitializers: DeviceAgnostic, PartialFunction, Utils
 @reexport using WeightInitializers
 @reexport using LuxCore: setup, apply, initialparameters, initialstates

-const BoolType = Union{StaticBool,Bool,Val{true},Val{false}}
-const InputType = Tuple{<:AbstractArray,Tuple{<:AbstractArray}}
-const IntegerType = Union{Integer,StaticInteger}
+const BoolType = Union{StaticBool, Bool, Val{true}, Val{false}}
+const InputType = Tuple{<:AbstractArray, Tuple{<:AbstractArray}}
+const IntegerType = Union{Integer, StaticInteger}

 #@compat(public, (initialparameters)) #do I need to add intialstates/parameters in compat?

@@ -42,18 +42,20 @@ include("models/hybridesn.jl")
 #extensions
 include("extensions/reca.jl")

-export ESNCell, StatefulLayer, LinearReadout, ReservoirChain, Collect, collectstates, train!, predict
+export ESNCell, StatefulLayer, LinearReadout, ReservoirChain, Collect, collectstates,
+       train!,
+       predict, reset_carry
 export SVMReadout
 export Pad, Extend, NLAT1, NLAT2, NLAT3, PartialSquare, ExtendedSquare
 export StandardRidge
 export chebyshev_mapping, informed_init, logistic_mapping, minimal_init,
-    modified_lm, scaled_rand, weighted_init, weighted_minimal
+       modified_lm, scaled_rand, weighted_init, weighted_minimal
 export block_diagonal, chaotic_init, cycle_jumps, delay_line, delay_line_backward,
-    double_cycle, forward_connection, low_connectivity, pseudo_svd, rand_sparse,
-    selfloop_cycle, selfloop_delayline_backward, selfloop_feedback_cycle,
-    selfloop_forward_connection, simple_cycle, true_double_cycle
+       double_cycle, forward_connection, low_connectivity, pseudo_svd, rand_sparse,
+       selfloop_cycle, selfloop_delayline_backward, selfloop_feedback_cycle,
+       selfloop_forward_connection, simple_cycle, true_double_cycle
 export add_jumps!, backward_connection!, delay_line!, reverse_simple_cycle!,
-    scale_radius!, self_loop!, simple_cycle!
+       scale_radius!, self_loop!, simple_cycle!
 export train
 export ESN, HybridESN, KnowledgeModel, DeepESN
 #reca
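
Below is a rough usage sketch of the exported surface (not taken from the package docs): the dimensions, the random training data, and the regularization value are placeholders, `setup`/`apply` come from the LuxCore reexport above, and the `ESN` constructor and `train!` method are the ones added in src/models/esn.jl further down.

    using Random, ReservoirComputing

    rng = Random.default_rng()
    esn = ESN(3, 100, 3)                 # in_dims, res_dims, out_dims; activation defaults to tanh
    ps, st = setup(rng, esn)             # LuxCore setup, reexported by the package

    train_data  = rand(Float32, 3, 500)  # placeholder sequence, features × time steps
    target_data = rand(Float32, 3, 500)  # placeholder targets, same number of columns

    ps, st = train!(esn, train_data, target_data, ps, st, StandardRidge(0.0); washout = 10)
    out, st = esn(train_data[:, end], ps, st)  # one forward step: cell -> state modifiers -> readout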

src/models/esn.jl

Lines changed: 61 additions & 104 deletions
@@ -1,16 +1,17 @@
 @concrete struct ESN <: AbstractLuxContainerLayer{(:cell, :states_modifiers, :readout)}
-    cell
-    states_modifiers
-    readout
+    cell::Any
+    states_modifiers::Any
+    readout::Any
 end

 _wrap_layer(x) = x isa Function ? WrappedFunction(x) : x
 _wrap_layers(xs::Tuple) = map(_wrap_layer, xs)

-function ESN(in_dims::IntegerType, res_dims::IntegerType, out_dims::IntegerType, activation=tanh;
-        readout_activation=identity,
-        state_modifiers=(),
-        kwargs...)
+function ESN(in_dims::IntegerType, res_dims::IntegerType,
+        out_dims::IntegerType, activation = tanh;
+        readout_activation = identity,
+        state_modifiers = (),
+        kwargs...)
     cell = StatefulLayer(ESNCell(in_dims => res_dims, activation; kwargs...))
     mods_tuple = state_modifiers isa Tuple || state_modifiers isa AbstractVector ?
                  Tuple(state_modifiers) : (state_modifiers,)
@@ -23,128 +24,84 @@ function initialparameters(rng::AbstractRNG, esn::ESN)
     ps_cell = initialparameters(rng, esn.cell)
     ps_mods = map(l -> initialparameters(rng, l), esn.states_modifiers) |> Tuple
     ps_ro = initialparameters(rng, esn.readout)
-    return (cell=ps_cell, states_modifiers=ps_mods, readout=ps_ro)
+    return (cell = ps_cell, states_modifiers = ps_mods, readout = ps_ro)
 end

 function initialstates(rng::AbstractRNG, esn::ESN)
     st_cell = initialstates(rng, esn.cell)
     st_mods = map(l -> initialstates(rng, l), esn.states_modifiers) |> Tuple
     st_ro = initialstates(rng, esn.readout)
-    return (cell=st_cell, states_modifiers=st_mods, readout=st_ro)
+    return (cell = st_cell, states_modifiers = st_mods, readout = st_ro)
 end

-@inline function _apply_seq(layers::Tuple, x, ps::Tuple, st::Tuple)
-    n = length(layers)
-    new_st_parts = Vector{Any}(undef, n)
-    @inbounds for i in 1:n
-        x, sti = apply(layers[i], x, ps[i], st[i])
-        new_st_parts[i] = sti
+@inline function _apply_seq(layers::Tuple, inp, ps::Tuple, st::Tuple)
+    new_st_parts = Vector{Any}(undef, length(layers))
+    for idx in eachindex(layers)
+        inp, sti = apply(layers[idx], inp, ps[idx], st[idx])
+        new_st_parts[idx] = sti
     end
-    return x, tuple(new_st_parts...)
+    return inp, tuple(new_st_parts...)
 end

-function (m::ESN)(x, ps, st)
-    y, st_cell = apply(m.cell, x, ps.cell, st.cell)
-    y, st_mods = _apply_seq(m.states_modifiers, y, ps.states_modifiers, st.states_modifiers)
-    y, st_ro = apply(m.readout, y, ps.readout, st.readout)
-    return y, (cell=st_cell, states_modifiers=st_mods, readout=st_ro)
+function (esn::ESN)(inp::AbstractVector, ps, st)
+    out, st_cell = apply(esn.cell, inp, ps.cell, st.cell)
+    out, st_mods = _apply_seq(
+        esn.states_modifiers, out, ps.states_modifiers, st.states_modifiers)
+    out, st_ro = apply(esn.readout, out, ps.readout, st.readout)
+    return out, (cell = st_cell, states_modifiers = st_mods, readout = st_ro)
 end

-function reset_carry(esn::ESN, st; mode=:zeros, value=nothing, rng=nothing)
-    # Find current carry & infer shape/type
-    c = get(st.cell, :carry, nothing)
-    if c === nothing
+function reset_carry(rng::AbstractRNG, esn::ESN, ps, st; init_carry = nothing)
+    carry = get(st.cell, :carry, nothing)
+    if carry === nothing
         outd = esn.cell.cell.out_dims
-        T = Float32
-        sz = (outd, 1)
+        sz = outd
     else
-        h = c[1] # carry is usually a 1-tuple (h,)
-        T = eltype(h)
-        sz = size(h)
+        state = first(carry)
+        sz = size(state, 1)
     end

-    new_h = begin
-        if mode === :zeros
-            zeros(T, sz)
-        elseif mode === :randn
-            rng = rng === nothing ? Random.default_rng() : rng
-            randn(rng, T, sz...)
-        elseif mode === :value
-            @assert value !== nothing "Provide `value=` when mode=:value"
-            fill(T(value), sz)
-        else
-            error("Unknown mode=$(mode). Use :zeros, :randn, or :value.")
-        end
+    if init_carry === nothing
+        new_state = nothing
+    else
+        new_state = init_carry(rng, sz, 1)
+        new_state = (new_state,)
     end

-    new_cell = merge(st.cell, (; carry=(new_h,)))
-    return (cell=new_cell, states_modifiers=st.states_modifiers, readout=st.readout)
+    new_cell = merge(st.cell, (; carry = new_state))
+    return ps,
+    (cell = new_cell, states_modifiers = st.states_modifiers, readout = st.readout)
 end

-_set_readout_weight(ps_readout::NamedTuple, W) = merge(ps_readout, (; weight=W))
-
-function train!(m::ESN, train_data::AbstractMatrix, target_data::AbstractMatrix,
-        ps, st, train_method=StandardRidge(0.0);
-        washout::Int=0, return_states::Bool=false)
+_set_readout_weight(ps_readout::NamedTuple, wro) = merge(ps_readout, (; weight = wro))

+function collectstates(esn::ESN, data::AbstractMatrix, ps, st::NamedTuple)
     newst = st
-    collected = Vector{Any}(undef, size(train_data, 2))
-    @inbounds for (t, x) in enumerate(eachcol(train_data))
-        y, st_cell = apply(m.cell, x, ps.cell, newst.cell)
-        y, st_mods = _apply_seq(m.states_modifiers, y, ps.states_modifiers, newst.states_modifiers)
-        collected[t] = copy(y)
-        newst = (cell=st_cell, states_modifiers=st_mods, readout=newst.readout)
+    collected = Any[]
+    for inp in eachcol(data)
+        cell_y, st_cell = apply(esn.cell, inp, ps.cell, newst.cell)
+        state_t, st_mods = _apply_seq(
+            esn.states_modifiers, cell_y, ps.states_modifiers, newst.states_modifiers)
+        push!(collected, copy(state_t))
+        newst = (cell = st_cell, states_modifiers = st_mods, readout = newst.readout)
     end
-    states = eltype(train_data).(reduce(hcat, collected))
-
-    states_wo, targets_wo =
-        washout > 0 ? _apply_washout(states, target_data, washout) : (states, target_data)
-
-    W = train(train_method, states_wo, targets_wo)
-    ps2 = (cell=ps.cell,
-        states_modifiers=ps.states_modifiers,
-        readout=_set_readout_weight(ps.readout, W))
-
-    return return_states ? ((ps2, newst), states_wo) : (ps2, newst)
+    states = eltype(data).(reduce(hcat, collected))
+    @assert !isempty(collected)
+    states_raw = reduce(hcat, collected)
+    states = eltype(data).(states_raw)
+    return states, newst
 end

-_basefuncstr(x) = sprint(show, x)
+function train!(esn::ESN, train_data::AbstractMatrix, target_data::AbstractMatrix,
+        ps, st, train_method = StandardRidge(0.0);
+        washout::Int = 0, return_states::Bool = false)
+    states, newst = collectstates(esn, train_data, ps, st)
+    states_wo, targets_wo = washout > 0 ? _apply_washout(states, target_data, washout) :
+                            (states, target_data)
+    wro = train(train_method, states_wo, targets_wo)
+    ps2 = (cell = ps.cell,
+        states_modifiers = ps.states_modifiers,
+        readout = _set_readout_weight(ps.readout, wro))

-_getflag(x, sym::Symbol, default=false) = begin
-    v = known(getproperty(x, Val(sym)))
-    v === nothing ? default : v
-end
-
-function Base.show(io::IO, ::MIME"text/plain", rc::ReservoirChain)
-    L = collect(pairs(rc.layers))
-    if !isempty(L) && (L[1][2] isa StatefulLayer) && (L[end][2] isa LinearReadout)
-        sl = L[1][2]
-        ro = L[end][2]
-        if sl.cell isa ESNCell
-            esn = sl.cell
-            mods = (length(L) > 2) ? map(x -> _basefuncstr(x[2]), L[2:end-1]) : String[]
-            print(io, "ESN($(esn.in_dims) => $(esn.out_dims); ",
-                "activation=", esn.activation,
-                ", leak=", esn.leak_coefficient,
-                ", readout=", ro.activation)
-            ic = _getflag(ro, :include_collect, false)
-            ic && print(io, ", include_collect=true")
-            if !_getflag(esn, :use_bias, false)
-                print(io, ", use_bias=false")
-            end
-            if !isempty(mods)
-                print(io, ", modifiers=[", join(mods, ", "), "]")
-            end
-            print(io, ")")
-            return
-        end
-    end
-    strs = map(x -> _basefuncstr(x[2]), L)
-    if length(strs) <= 2
-        print(io, "ReservoirChain(", join(strs, ", "), ")")
-    else
-        print(io, "ReservoirChain(\n  ", join(strs, ",\n  "), "\n)")
-    end
+    return return_states ? ((ps2, newst), states_wo) : (ps2, newst)
 end
-
-Base.show(io::IO, rc::ReservoirChain) = show(io, MIME"text/plain"(), rc)
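
A minimal sketch of calling the reworked `reset_carry` (assuming the `rng`, `esn`, `ps`, and `st` from the sketch after the first file; `zeros32` is one of the WeightInitializers initializers reexported by the package):

    ps, st = reset_carry(rng, esn, ps, st)                        # no init_carry: the stored carry is cleared to nothing
    ps, st = reset_carry(rng, esn, ps, st; init_carry = zeros32)  # rebuild the hidden state as a res_dims × 1 block of zeros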

src/train.jl

Lines changed: 12 additions & 12 deletions
@@ -19,22 +19,21 @@ struct StandardRidge
     reg::Number
 end

-function StandardRidge(::Type{T}, reg) where {T<:Number}
+function StandardRidge(::Type{T}, reg) where {T <: Number}
     return StandardRidge(T.(reg))
 end

 function StandardRidge()
     return StandardRidge(0.0)
 end

-
 function _apply_washout(states::AbstractMatrix, targets::AbstractMatrix, washout::Integer)
-    @assert washout ≥ 0 "washout must be ≥ 0"
+    @assert washout≥0 "washout must be ≥ 0"
     len_states = size(states, 2)
-    @assert washout < len_states "washout=$washout is ≥ number of time steps=$len_states"
+    @assert washout<len_states "washout=$washout is ≥ number of time steps=$len_states"
     first_idx = washout + 1
-    states_wo = states[:, washout+1:end]
-    targets_wo = targets[:, washout+1:end]
+    states_wo = states[:, (washout + 1):end]
+    targets_wo = targets[:, (washout + 1):end]
     return states_wo, targets_wo
 end

@@ -75,10 +74,11 @@ The returned state is the final state after running through the full sequence.
 `include_collect=true`, or insert an explicit [`Collect()`](@ref) earlier in the chain.
 """
 function train!(rc::ReservoirChain, train_data, target_data, ps, st,
-        train_method=StandardRidge(0.0);
-        washout::Int=0, return_states::Bool=false)
+        train_method = StandardRidge(0.0);
+        washout::Int = 0, return_states::Bool = false)
     states, st_after = collectstates(rc, train_data, ps, st)
-    states_wo, traindata_wo = washout > 0 ? _apply_washout(states, target_data, washout) : (states, target_data)
+    states_wo, traindata_wo = washout > 0 ? _apply_washout(states, target_data, washout) :
+                              (states, target_data)
     output_matrix = train(train_method, states_wo, traindata_wo)
     ps2, st_after = addreadout!(rc, output_matrix, ps, st_after)
     return return_states ? ((ps2, st_after), states_wo) : (ps2, st_after)
@@ -134,9 +134,9 @@ end
 end

 function addreadout!(rc::ReservoirChain,
-                     W::AbstractMatrix,
-                     ps::NamedTuple,
-                     st::NamedTuple)
+        W::AbstractMatrix,
+        ps::NamedTuple,
+        st::NamedTuple)
     @assert propertynames(rc.layers) == propertynames(ps)
     new_ps = _addreadout(rc.layers, ps, W)
     return new_ps, st
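
As a hedged sketch of the washout path above (the `rc`, `ps`, `st`, and data variables stand in for a `ReservoirChain` setup that is not part of this diff): with `return_states = true`, `train!` also returns the collected states after `_apply_washout` has dropped the first `washout` columns, so when one state is collected per input column the returned matrix is `washout` columns narrower than the input.

    (ps2, st2), states_wo = train!(rc, train_data, target_data, ps, st;
        washout = 10, return_states = true)
    size(states_wo, 2) == size(train_data, 2) - 10  # expected when one state is collected per column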
