Commit 26ef75b

tests: add tests for layers LinearReadout and ESNCell
1 parent 27c1213 commit 26ef75b

13 files changed: +326 -36 lines

Project.toml

Lines changed: 1 addition & 3 deletions

@@ -4,7 +4,6 @@ authors = ["Francesco Martinuzzi"]
 version = "0.11.4"
 
 [deps]
-Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
@@ -29,7 +28,6 @@ RCMLJLinearModelsExt = "MLJLinearModels"
 RCSparseArraysExt = "SparseArrays"
 
 [compat]
-Adapt = "4.1.1"
 Aqua = "0.8"
 ArrayInterface = "7.19.0"
 CellularAutomata = "0.0.6"
@@ -64,4 +62,4 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Aqua", "Test", "SafeTestsets", "DifferentialEquations", "MLJLinearModels", "LIBSVM", "Statistics", "SparseArrays", "CellularAutomata", "JET"]
+test = ["Aqua", "Test", "SafeTestsets", "DifferentialEquations", "MLJLinearModels", "LIBSVM", "Statistics", "SparseArrays", "JET"]

docs/src/api/layers.md

Lines changed: 2 additions & 1 deletion

@@ -3,9 +3,10 @@
 ## Base Layers
 ```@doc
 ReservoirChain
-Readout
 Collect
 StatefulLayer
+LinearReadout
+SVMReadout
 ```
 
 ## Echo State Networks

docs/src/api/models.md

Lines changed: 1 addition & 0 deletions

@@ -3,4 +3,5 @@
 ```@docs
 ESN
 DeepESN
+HybridESN
 ```

src/ReservoirComputing.jl

Lines changed: 3 additions & 4 deletions

@@ -1,6 +1,5 @@
 module ReservoirComputing
 
-using Adapt: adapt
 using ArrayInterface: ArrayInterface
 using Compat: @compat
 using ConcreteStructs: @concrete
@@ -15,13 +14,13 @@ using Static: StaticBool, StaticInt, StaticSymbol,
 using Reexport: Reexport, @reexport
 using WeightInitializers: DeviceAgnostic, PartialFunction, Utils
 @reexport using WeightInitializers
-@reexport using LuxCore: setup, apply
+@reexport using LuxCore: setup, apply, initialparameters, initialstates
 
 const BoolType = Union{StaticBool,Bool,Val{true},Val{false}}
 const InputType = Tuple{<:AbstractArray,Tuple{<:AbstractArray}}
 const IntegerType = Union{Integer,StaticInteger}
 
-#@compat(public, (create_states)) #do I need to add intialstates/parameters in compat?
+#@compat(public, (initialparameters)) #do I need to add intialstates/parameters in compat?
 
 #layers
 include("layers/basic.jl")
@@ -42,7 +41,7 @@ include("models/hybridesn.jl")
 #extensions
 include("extensions/reca.jl")
 
-export ESNCell, StatefulLayer, Readout, ReservoirChain, Collect, collectstates, train!, predict
+export ESNCell, StatefulLayer, LinearReadout, ReservoirChain, Collect, collectstates, train!, predict
 export SVMReadout
 export Pad, Extend, NLAT1, NLAT2, NLAT3, PartialSquare, ExtendedSquare
 export StandardRidge
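
A minimal sketch of the newly reexported LuxCore names; the ESN dimensions and positional defaults below are illustrative assumptions, not part of this diff:

```julia
using Random, ReservoirComputing

rng = Random.Xoshiro(42)
esn = ESN(3, 300, 3)                # in_dims, res_dims, out_dims; defaults assumed

ps, st = setup(rng, esn)            # setup bundles the two calls below
ps2 = initialparameters(rng, esn)   # now reexported alongside setup/apply
st2 = initialstates(rng, esn)
```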

src/inits/esn_inits.jl

Lines changed: 2 additions & 2 deletions

@@ -787,7 +787,7 @@ function pseudo_svd(rng::AbstractRNG, ::Type{T}, dims::Integer...;
     reservoir_matrix = create_diag(rng, T, res_dim, T(max_value);
         sorted=sorted, reverse_sort=reverse_sort)
 
-    tmp = get_sparsity(R, res_dim)
+    tmp = get_sparsity(reservoir_matrix, res_dim)
     while tmp <= sparsity
         i = rand_range(rng, res_dim)
         j = rand_range(rng, res_dim)
@@ -813,7 +813,7 @@ end
 
 function create_diag(rng::AbstractRNG, ::Type{T}, res_dim::Integer, max_value::Number;
         sorted::Bool=true, reverse_sort::Bool=false) where {T<:Number}
-    diag_matrix = DeviceAgnostic.rand(rng, T, Int(n)) .* T(max_value)
+    diag_matrix = DeviceAgnostic.rand(rng, T, Int(res_dim)) .* T(max_value)
 
     if sorted
         sort!(diag_matrix)
src/layers/basic.jl

Lines changed: 24 additions & 19 deletions

@@ -2,10 +2,10 @@ abstract type AbstractReservoirCollectionLayer <: AbstractLuxLayer end
 abstract type AbstractReservoirRecurrentCell <: AbstractLuxLayer end
 abstract type AbstractReservoirTrainableLayer <: AbstractLuxLayer end
 
-### Readout
+### LinearReadout
 # adapted from lux layers/basic Dense
 @doc raw"""
-    Readout(in_dims => out_dims, [activation];
+    LinearReadout(in_dims => out_dims, [activation];
         use_bias=false, include_collect=true)
 
 Linear readout layer with optional bias and elementwise activation. Intended as
@@ -48,7 +48,7 @@ before this layer (logically inserting a [`Collect()`](@ref) right before it).
 Otherwise training may operate on the post-readout signal,
 which is usually unintended.
 """
-@concrete struct Readout <: AbstractReservoirTrainableLayer
+@concrete struct LinearReadout <: AbstractReservoirTrainableLayer
     activation
     in_dims <: IntegerType
     out_dims <: IntegerType
@@ -58,32 +58,32 @@ before this layer (logically inserting a [`Collect()`](@ref) right before it).
     include_collect <: StaticBool
 end
 
-function Readout(mapping::Pair{<:IntegerType,<:IntegerType}, activation=identity; kwargs...)
-    return Readout(first(mapping), last(mapping), activation; kwargs...)
+function LinearReadout(mapping::Pair{<:IntegerType,<:IntegerType}, activation=identity; kwargs...)
+    return LinearReadout(first(mapping), last(mapping), activation; kwargs...)
 end
 
-function Readout(in_dims::IntegerType, out_dims::IntegerType, activation=identity;
+function LinearReadout(in_dims::IntegerType, out_dims::IntegerType, activation=identity;
         init_weight=rand32, init_bias=rand32, include_collect::BoolType=True(),
         use_bias::BoolType=False())
-    return Readout(activation, in_dims, out_dims, init_weight, init_bias, static(use_bias), static(include_collect))
+    return LinearReadout(activation, in_dims, out_dims, init_weight, init_bias, static(use_bias), static(include_collect))
 end
 
-function initialparameters(rng::AbstractRNG, ro::Readout)
+function initialparameters(rng::AbstractRNG, ro::LinearReadout)
     weight = ro.init_weight(rng, ro.out_dims, ro.in_dims)
 
     if has_bias(ro)
-        return (; weight, bias=ro.init_bias(rng, Float32, ro.out_dims))
+        return (; weight, bias=ro.init_bias(rng, ro.out_dims))
     else
        return (; weight)
     end
 end
 
-parameterlength(ro::Readout) = ro.out_dims * ro.in_dims + has_bias(ro) * ro.out_dims
-statelength(ro::Readout) = 0
+parameterlength(ro::LinearReadout) = ro.out_dims * ro.in_dims + has_bias(ro) * ro.out_dims
+statelength(ro::LinearReadout) = 0
 
-outputsize(ro::Readout, _, ::AbstractRNG) = (ro.out_dims,)
+outputsize(ro::LinearReadout, _, ::AbstractRNG) = (ro.out_dims,)
 
-function (ro::Readout)(inp::AbstractArray, ps, st::NamedTuple)
+function (ro::LinearReadout)(inp::AbstractArray, ps, st::NamedTuple)
     out_tmp = ps.weight * inp
     if has_bias(ro)
         out_tmp += ps.bias
@@ -92,8 +92,8 @@ function (ro::Readout)(inp::AbstractArray, ps, st::NamedTuple)
     return output, st
 end
 
-function Base.show(io::IO, ro::Readout)
-    print(io, "Readout($(ro.in_dims) => $(ro.out_dims)")
+function Base.show(io::IO, ro::LinearReadout)
+    print(io, "LinearReadout($(ro.in_dims) => $(ro.out_dims)")
     (ro.activation == identity) || print(io, ", $(ro.activation)")
     has_bias(ro) || print(io, ", use_bias=false")
     ic = known(getproperty(ro, Val(:include_collect)))
@@ -136,7 +136,7 @@ vectors are concatenated with `vcat` in order of appearance.
 
 ## Notes
 
-- When used with a single `Collect()` before a [`Readout`](@ref), training uses exactly
+- When used with a single `Collect()` before a [`LinearReadout`](@ref), training uses exactly
   the tensor right before the readout (e.g., the reservoir state).
 - With **multiple** `Collect()` layers (e.g., after different submodules), the
   per-step features are `vcat`-ed in chain order to form one feature vector.
@@ -150,7 +150,7 @@ vectors are concatenated with `vcat` in order of appearance.
     StatefulLayer(ESNCell(3 => 300)),
     NLAT2(),
     Collect(), # <-- collect the 300-dim reservoir after NLAT2
-    Readout(300 => 3; include_collect=false) # <-- toggle off the default Collect()
+    LinearReadout(300 => 3; include_collect=false) # <-- toggle off the default Collect()
 )
 ```
 """
@@ -173,7 +173,7 @@ in a step, the feature defaults to the final vector exiting the chain for
 that time step.
 
 !!! note
-    If your [`Readout`](@ref) layer was created with `include_collect=true`
+    If your [`LinearReadout`](@ref) layer was created with `include_collect=true`
     (default behaviour), a collection point is placed immediately before the readout,
     so the collected features are the inputs to the readout.
 
@@ -209,7 +209,12 @@ function collectstates(rc::AbstractLuxLayer, data::AbstractMatrix, ps, st::NamedTuple)
         end
         push!(collected, state_vec === nothing ? copy(inp_tmp) : state_vec)
     end
-    states = eltype(data).(reduce(hcat, collected))
+    @assert !isempty(collected)
+    firstcol = collected[1]
+    Tcol = eltype(firstcol)
+    empty_mat = zeros(Tcol, length(firstcol), 0)
+    states_raw = reduce(hcat, collected; init=empty_mat)
+    states = eltype(data).(states_raw)
     return states, newst
 end
 
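
The rewrite guards the corner cases of `reduce(hcat, collected)`: reducing an empty collection throws, and reducing a single-element collection returns the element itself rather than a matrix. Supplying a zero-column `init` fixes both. A plain-Julia illustration (not package API):

```julia
cols = [rand(Float32, 5) for _ in 1:3]
init = zeros(Float32, 5, 0)

reduce(hcat, cols)                # 5×3 Matrix
reduce(hcat, cols[1:1])           # 5-element Vector: single elements pass through
reduce(hcat, cols[1:1]; init)     # 5×1 Matrix
reduce(hcat, empty!(cols); init)  # 5×0 Matrix; without init this throws
```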

src/layers/lux_layers.jl

Lines changed: 2 additions & 2 deletions

@@ -70,12 +70,12 @@ end
 
 wrap_functions_in_chain_call(x) = x
 
-_readout_include_collect(ro::Readout) = begin
+_readout_include_collect(ro::LinearReadout) = begin
     res = known(getproperty(ro, Val(:include_collect)))
     res === nothing ? false : res
 end
 
-function wrap_functions_in_chain_call(ro::Readout)
+function wrap_functions_in_chain_call(ro::LinearReadout)
     return _readout_include_collect(ro) ? (Collect(), ro) : ro
 end
 
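
`wrap_functions_in_chain_call` is internal, but its observable effect (per the `LinearReadout` docstring in this commit) is easy to state:

```julia
using ReservoirComputing

ro_auto = LinearReadout(300 => 3)   # include_collect=true by default
# -> expanded to (Collect(), ro_auto) during chain construction

ro_plain = LinearReadout(300 => 3; include_collect=false)
# -> left as-is; pair it with an explicit Collect() placed wherever
#    states should be gathered
```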

src/models/deepesn.jl

Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ function DeepESN(in_dims::Int,
         end
         prev = res_dims[res]
     end
-    ro = Readout(prev => out_dims, readout_activation)
+    ro = LinearReadout(prev => out_dims, readout_activation)
     return ReservoirChain((layers..., ro)...)
 end
 
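
A construction sketch; `res_dims[res]` in the hunk implies a per-layer vector of reservoir sizes, but the full `DeepESN` keyword surface is not shown here, so treat the call shape as an assumption:

```julia
using Random, ReservoirComputing

desn = DeepESN(3, [200, 100], 3)   # two stacked reservoirs (sizes assumed)
ps, st = setup(Random.Xoshiro(0), desn)
```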

src/models/esn.jl

Lines changed: 2 additions & 2 deletions

@@ -5,7 +5,7 @@ function ESN(in_dims::IntegerType, res_dims::IntegerType, out_dims::IntegerType,
     cell = ESNCell(in_dims => res_dims, activation; kwargs...)
     mods = state_modifiers isa Tuple || state_modifiers isa AbstractVector ?
            Tuple(state_modifiers) : (state_modifiers,)
-    ro = Readout(res_dims => out_dims, readout_activation)
+    ro = LinearReadout(res_dims => out_dims, readout_activation)
     return ReservoirChain((StatefulLayer(cell), mods..., ro)...)
 end
 
@@ -18,7 +18,7 @@ end
 
 function Base.show(io::IO, ::MIME"text/plain", rc::ReservoirChain)
     L = collect(pairs(rc.layers))
-    if !isempty(L) && (L[1][2] isa StatefulLayer) && (L[end][2] isa Readout)
+    if !isempty(L) && (L[1][2] isa StatefulLayer) && (L[end][2] isa LinearReadout)
         sl = L[1][2]
         ro = L[end][2]
         if sl.cell isa ESNCell
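
An end-to-end sketch with the exported workflow names (`train!` and `predict` are exported in src/ReservoirComputing.jl above); the exact call shapes and data layout are assumptions:

```julia
using Random, ReservoirComputing

rng = Random.Xoshiro(7)
input  = rand(Float32, 3, 500)    # features × time steps (layout assumed)
target = rand(Float32, 3, 500)

esn = ESN(3, 300, 3)
ps, st = setup(rng, esn)
ps, st = train!(esn, input, target, ps, st)               # assumed signature
output, st = predict(esn, rand(Float32, 3, 10), ps, st)   # assumed signature
```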

src/models/hybridesn.jl

Lines changed: 2 additions & 2 deletions

@@ -66,7 +66,7 @@ end
         kwargs...)
 
 Build a hybrid ESN as a `ReservoirChain`:
-`StatefulLayer(ESNCell) → modifiers → AttachStream(train KB) → Readout`.
+`StatefulLayer(ESNCell) → modifiers → AttachStream(train KB) → LinearReadout`.
 """
 function HybridESN(km::KnowledgeModel,
         in_dims::Integer, res_dims::Integer, out_dims::Integer,
@@ -82,7 +82,7 @@ function HybridESN(km::KnowledgeModel,
     stream_train = kb_stream_train(km, km.datasize)
     d_kb = size(stream_train, 1)
 
-    ro = Readout((res_dims + d_kb) => out_dims, readout_activation;
+    ro = LinearReadout((res_dims + d_kb) => out_dims, readout_activation;
         include_collect=static(include_collect))
 
     return ReservoirChain((StatefulLayer(cell), mods..., AttachStream(stream_train), ro)...)
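
The readout width must cover the reservoir state plus the attached knowledge-base features, matching `(res_dims + d_kb) => out_dims` above. Illustrative numbers only; `KnowledgeModel` construction is omitted:

```julia
using ReservoirComputing

res_dims, d_kb, out_dims = 300, 3, 3
ro = LinearReadout((res_dims + d_kb) => out_dims)
```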
