Commit d9f98cb

fix: test and docs fixes
1 parent 63384da commit d9f98cb

File tree

14 files changed (+42, -46 lines)

docs/src/tutorials/scratch.md
Lines changed: 1 addition & 1 deletion

````diff
@@ -56,7 +56,7 @@ ps_s, st_s = setup(rng, esn_scratch)
 rng = MersenneTwister(17)
 ps, st = setup(rng, esn)
 
-ps_s.layer_1.input_matrix == ps.cell.input_matrix
+ps_s.layer_1.input_matrix == ps.reservoir.input_matrix
 ```
 
 Both the models can be trained using [`train!`](@ref), and predictions can be
````
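The fix points the tutorial's equality check at the packaged model's `reservoir` parameter group instead of the nonexistent `cell` one. A minimal sketch of the check as it now reads, assuming `esn_scratch`, `esn`, and `setup` are defined earlier in the tutorial:

```julia
using Random

# Reseeding with the same MersenneTwister makes both models draw identical
# weights, so the input matrices can be compared directly.
rng = MersenneTwister(17)
ps_s, st_s = setup(rng, esn_scratch)   # hand-built chain: params under layer_1
rng = MersenneTwister(17)
ps, st = setup(rng, esn)               # packaged ESN: params under reservoir
ps_s.layer_1.input_matrix == ps.reservoir.input_matrix  # expected: true
```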

ext/RCCellularAutomataExt.jl
Lines changed: 4 additions & 4 deletions

```diff
@@ -30,10 +30,10 @@ function encoding(rm::RandomMaps, input_vector, tot_encoded_vector)
 end
 
 function single_encoding(input_vector, encoded_vector, map)
-    @assert length(map) == length(input_vector) """
-    RandomMaps mismatch: map length = $(length(map)) but input length = $(length(input_vector)).
-    (Build RandomMaps with in_dims = size(input, 1) used at training time.)
-    """
+    @assert length(map)==length(input_vector) """
+    RandomMaps mismatch: map length = $(length(map)) but input length = $(length(input_vector)).
+    (Build RandomMaps with in_dims = size(input, 1) used at training time.)
+    """
     new_enc_vec = copy(encoded_vector)
 
     for i in 1:size(input_vector, 1)
```
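The hunk is whitespace-only (the formatter drops the spaces around `==` when `@assert` carries a message), but the pattern it touches is worth noting: a triple-quoted string lets the failure message span several lines and interpolate values. A standalone sketch with hypothetical data:

```julia
map = collect(1:4)       # hypothetical map of length 4
input_vector = rand(3)   # deliberately mismatched length

# The message is only built if the assertion fails, then raised as an
# AssertionError carrying the interpolated lengths.
@assert length(map)==length(input_vector) """
RandomMaps mismatch: map length = $(length(map)) but input length = $(length(input_vector)).
(Build RandomMaps with in_dims = size(input, 1) used at training time.)
"""
```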

ext/RCLIBSVMExt.jl
Lines changed: 1 addition & 1 deletion

```diff
@@ -7,7 +7,7 @@ import ReservoirComputing: train
 
 function train(svr::LIBSVM.AbstractSVR,
         states::AbstractArray, target::AbstractArray)
-    @assert size(states, 2) == size(target, 2) "states and target must share columns."
+    @assert size(states, 2)==size(target, 2) "states and target must share columns."
     perm_states = permutedims(states)
     size_target = size(target, 1)
```
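A hedged usage sketch of this method, assuming LIBSVM.jl's `EpsilonSVR` with its default keyword constructor; both arrays are laid out with time along the columns, which is what the assertion checks:

```julia
using LIBSVM, ReservoirComputing

states = rand(10, 50)   # 10 collected reservoir features over 50 steps
target = rand(1, 50)    # 1 target series over the same 50 steps

# Column counts match, so the assertion passes and an SVR readout is fit.
readout = train(EpsilonSVR(), states, target)
```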

ext/RCMLJLinearModelsExt.jl
Lines changed: 1 addition & 1 deletion

```diff
@@ -5,7 +5,7 @@ using MLJLinearModels
 function ReservoirComputing.train(regressor::MLJLinearModels.GeneralizedLinearRegression,
         states::AbstractMatrix{<:Real}, target::AbstractMatrix{<:Real};
         kwargs...)
-    @assert size(states, 2) == size(target, 2) "states and target must share the same number of columns."
+    @assert size(states, 2)==size(target, 2) "states and target must share the same number of columns."
 
     if regressor.fit_intercept
         throw(ArgumentError("fit_intercept=true not supported here. \
```
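A hedged sketch of calling this method; it assumes MLJLinearModels' `RidgeRegression` with its `lambda` and `fit_intercept` keywords. Note the guard above: `fit_intercept` must be `false`.

```julia
using MLJLinearModels, ReservoirComputing

states = rand(20, 100)   # 20 features, 100 time steps
target = rand(2, 100)    # 2 outputs over the same steps

# fit_intercept = false is required by the extension's guard.
regressor = RidgeRegression(; lambda = 1e-3, fit_intercept = false)
readout = train(regressor, states, target)
```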

src/inits/esn_inits.jl
Lines changed: 5 additions & 5 deletions

```diff
@@ -86,7 +86,7 @@ end
 function apply_scale!(input_matrix::AbstractArray,
         scaling::Tuple{<:Number, <:Number}, ::Type{T}) where {T}
     lower, upper = T(scaling[1]), T(scaling[2])
-    @assert lower < upper "lower < upper required"
+    @assert lower<upper "lower < upper required"
     scale = upper - lower
     @. input_matrix = input_matrix * scale + lower
     return input_matrix
@@ -95,7 +95,7 @@ end
 function apply_scale!(input_matrix::AbstractMatrix,
         scaling::AbstractVector, ::Type{T}) where {T <: Number}
     ncols = size(input_matrix, 2)
-    @assert length(scaling) == ncols "need one scaling per column"
+    @assert length(scaling)==ncols "need one scaling per column"
     for (idx, col) in enumerate(eachcol(input_matrix))
         apply_scale!(col, scaling[idx], T)
     end
@@ -2020,9 +2020,9 @@ function block_diagonal(rng::AbstractRNG, ::Type{T}, dims::Integer...;
         \n"
     end
     weights = isa(weight, AbstractVector) ? T.(weight) : fill(T(weight), num_blocks)
-    @assert length(weights) == num_blocks "
-    weight vector must have length = number of blocks
-    "
+    @assert length(weights)==num_blocks "
+    weight vector must have length = number of blocks
+    "
     reservoir_matrix = DeviceAgnostic.zeros(rng, T, n_rows, n_cols)
     for block in 1:num_blocks
         row_start = (block - 1) * block_size + 1
```
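The scaling rule in `apply_scale!` is a plain affine map into `(lower, upper)`. A standalone illustration of the tuple form, with hypothetical values:

```julia
x = rand(Float32, 4, 2)              # entries drawn from [0, 1)
lower, upper = -0.5f0, 0.5f0         # target interval; lower < upper required

# Same rule as apply_scale!: stretch by (upper - lower), then shift by lower.
@. x = x * (upper - lower) + lower   # entries now lie in [-0.5, 0.5)
```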

src/layers/basic.jl
Lines changed: 3 additions & 2 deletions

```diff
@@ -44,7 +44,7 @@ before this layer (logically inserting a [`Collect`](@ref) right before it).
 
 - In ESN workflows, readout weights are typically replaced via ridge regression in
   [`train!`](@ref). Therefore, how `LinearReadout` gets initialized is of no consequence.
-  Additionally, the dimesions will also not be taken into account, as [`train!`](@ref)
+  Additionally, the dimensions will also not be taken into account, as [`train!`](@ref)
   will replace the weights.
 - If you set `include_collect=false`, make sure a [`Collect`](@ref) appears earlier in the chain.
   Otherwise training may operate on the post-readout signal,
@@ -60,7 +60,8 @@ before this layer (logically inserting a [`Collect`](@ref) right before it).
     include_collect <: StaticBool
 end
 
-function LinearReadout(mapping::Pair{<:IntegerType, <:IntegerType}, activation = identity; kwargs...)
+function LinearReadout(
+        mapping::Pair{<:IntegerType, <:IntegerType}, activation = identity; kwargs...)
     return LinearReadout(first(mapping), last(mapping), activation; kwargs...)
 end
 
```
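As the hunk shows, the `Pair` method only forwards to the positional constructor, so the two spellings below should be interchangeable (a sketch, assuming the default `identity` activation):

```julia
ro_pair = LinearReadout(300 + 3 => 3)   # pair form, as in the package docs
ro_pos  = LinearReadout(303, 3)         # the forwarded-to positional form
```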

src/layers/lux_layers.jl
Lines changed: 1 addition & 1 deletion

```diff
@@ -11,7 +11,7 @@ Base.show(io::IO, wf::WrappedFunction) = print(io, "WrappedFunction(", wf.func,
 @doc raw"""
     StatefulLayer(cell::AbstractReservoirRecurrentCell)
 
-A lightweight wrapper that makes a recurrent cell carry its imput state to the
+A lightweight wrapper that makes a recurrent cell carry its input state to the
 next step.
 
 ## Arguments
```

src/predict.jl
Lines changed: 1 addition & 1 deletion

```diff
@@ -60,7 +60,7 @@ end
 
 function predict(rc::AbstractLuxLayer, data::AbstractMatrix, ps, st)
     T = size(data, 2)
-    @assert T ≥ 1 "data must have at least one time step (columns)."
+    @assert T≥1 "data must have at least one time step (columns)."
 
     y1, st = apply(rc, data[:, 1], ps, st)
     Y = similar(y1, size(y1, 1), T)
```
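`predict` evaluates the first column to learn the output size, then preallocates the result. A standalone sketch of that pattern, with a stand-in function in place of `apply(rc, x, ps, st)`:

```julia
f(x) = 2 .* x                # stand-in for the model's single-step apply
data = rand(3, 5)            # 3 features, 5 time steps (columns)

y1 = f(data[:, 1])                            # first step fixes the output size
Y = similar(y1, size(y1, 1), size(data, 2))   # preallocate outputs-by-time
Y[:, 1] = y1
for t in 2:size(data, 2)
    Y[:, t] = f(data[:, t])
end
```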

src/reservoircomputer.jl
Lines changed: 6 additions & 3 deletions

```diff
@@ -29,7 +29,8 @@ features, and install trained readout weights.
 - `(y, st′)` where `y` is the readout output and `st′` contains the updated
   states of the reservoir, modifiers, and readout.
 """
-@concrete struct ReservoirComputer <: AbstractReservoirComputer{(:reservoir, :states_modifiers, :readout)}
+@concrete struct ReservoirComputer <:
+                 AbstractReservoirComputer{(:reservoir, :states_modifiers, :readout)}
     reservoir::Any
     states_modifiers::Any
     readout::Any
@@ -72,7 +73,8 @@ function (rc::AbstractReservoirComputer)(inp, ps, st)
     return out, merge(new_st, (readout = st_ro,))
 end
 
-function collectstates(rc::AbstractReservoirComputer, data::AbstractMatrix, ps, st::NamedTuple)
+function collectstates(
+        rc::AbstractReservoirComputer, data::AbstractMatrix, ps, st::NamedTuple)
     newst = st
     collected = Any[]
     for inp in eachcol(data)
@@ -130,7 +132,8 @@ function is provided, it is called to create a new initial hidden state.
 Same as above, but also returns the unchanged `ps` for convenience.
 
 """
-function resetcarry!(rng::AbstractRNG, rc::AbstractReservoirComputer, st; init_carry = nothing)
+function resetcarry!(
+        rng::AbstractRNG, rc::AbstractReservoirComputer, st; init_carry = nothing)
     carry = get(st.reservoir, :carry, nothing)
     if carry === nothing
         outd = rc.reservoir.cell.out_dims
```
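A hedged usage sketch of `resetcarry!` between independent sequences; `rc` and `st` are assumed to come from an earlier `setup`/training step, `init_carry` is left at its default, and the return shape is assumed from the docstring:

```julia
using Random

rng = Random.default_rng()
st = resetcarry!(rng, rc, st)   # drop the carried hidden state before new data
```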

src/states.jl
Lines changed: 6 additions & 4 deletions

````diff
@@ -1,6 +1,7 @@
 
-@inline function _apply_tomatrix(states_mod::F, states::AbstractMatrix) where {F <:
-                                                                               Function}
+@inline function _apply_tomatrix(
+        states_mod::F, states::AbstractMatrix) where {F <:
+                                                      Function}
     cols = axes(states, 2)
     states_1 = states_mod(states[:, first(cols)])
     new_states = similar(states_1, length(states_1), length(cols))
@@ -54,11 +55,12 @@ point with the input that it receives.
 esn = ReservoirChain(
     Extend(
         StatefulLayer(
-            ESNCell(3 => 300; init_reservoir = rand_sparse(; radius = 1.2, sparsity = 6/300))
+            ESNCell(
+                3 => 300; init_reservoir = rand_sparse(; radius = 1.2, sparsity = 6 / 300))
         )
     ),
     NLAT2(),
-    LinearReadout(300+3 => 3)
+    LinearReadout(300 + 3 => 3)
 )
 ```
````
