
Commit f77f32a

formatting, docstrings
1 parent 966d8b5 commit f77f32a

8 files changed: +119 -88 lines changed

ext/RCLIBSVMExt.jl

Lines changed: 4 additions & 3 deletions

@@ -2,7 +2,8 @@ module RCLIBSVMExt
 using ReservoirComputing
 using LIBSVM
 
-function ReservoirComputing.train(svr::LIBSVM.AbstractSVR, states, target)
+function ReservoirComputing.train(svr::LIBSVM.AbstractSVR,
+        states::AbstractArray, target::AbstractArray)
     out_size = size(target, 1)
     output_matrix = []
 
@@ -17,8 +18,8 @@ function ReservoirComputing.train(svr::LIBSVM.AbstractSVR, states, target)
     return OutputLayer(svr, output_matrix, out_size, target[:, end])
 end
 
-function ReservoirComputing.get_prediction(
-        training_method::LIBSVM.AbstractSVR, output_layer, x)
+function ReservoirComputing.get_prediction(training_method::LIBSVM.AbstractSVR,
+        output_layer::AbstractArray, x::AbstractArray)
     out = zeros(output_layer.out_size)
 
     for i in 1:(output_layer.out_size)
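With the new signatures the SVM path is still reached through the generic `train` entry point once LIBSVM.jl is loaded and the extension activates. A minimal usage sketch (shapes are illustrative; `EpsilonSVR()` is assumed here as one of LIBSVM's `AbstractSVR` models):

```julia
using ReservoirComputing, LIBSVM

states = rand(Float32, 50, 100)   # illustrative: 50 reservoir states over 100 steps
target = rand(Float32, 2, 100)    # illustrative: 2 output channels

svr = EpsilonSVR()                          # any LIBSVM.AbstractSVR should dispatch here
output_layer = train(svr, states, target)   # one SVR fit per output dimension
```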

ext/RCMLJLinearModelsExt.jl

Lines changed: 1 addition & 2 deletions

@@ -3,8 +3,7 @@ using ReservoirComputing
 using MLJLinearModels
 
 function ReservoirComputing.train(regressor::MLJLinearModels.GeneralizedLinearRegression,
-        states::AbstractArray{T},
-        target::AbstractArray{T};
+        states::AbstractArray{T}, target::AbstractArray{T};
         kwargs...) where {T <: Number}
     out_size = size(target, 1)
     output_layer = similar(target, size(target, 1), size(states, 1))
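This method dispatches on any `MLJLinearModels.GeneralizedLinearRegression`, so the convenience constructors from MLJLinearModels.jl can be handed straight to `train`. A rough sketch (assuming `LinearRegression()` as the regressor; data shapes are illustrative):

```julia
using ReservoirComputing, MLJLinearModels

states = rand(Float32, 50, 100)   # illustrative reservoir states
target = rand(Float32, 2, 100)    # illustrative targets

regressor = LinearRegression()    # builds a GeneralizedLinearRegression
output_layer = train(regressor, states, target)
```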

src/esn/deepesn.jl

Lines changed: 9 additions & 16 deletions

@@ -11,16 +11,16 @@ struct DeepESN{I, S, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork
     states::IS
 end
 
+const AbstractDriver = Union{AbstractReservoirDriver, GRU}
+
 """
     DeepESN(train_data, in_size, res_size; kwargs...)
 
 Constructs a Deep Echo State Network (ESN) model for
 processing sequential data through a layered architecture of reservoirs.
 This constructor allows for the creation of a deep learning model that
 benefits from the dynamic memory and temporal processing capabilities of ESNs,
-enhanced by the depth provided by multiple reservoir layers. It's particularly
-suited for complex sequential tasks where depth can help capture hierarchical
-temporal features.
+enhanced by the depth provided by multiple reservoir layers.
 
 # Parameters
@@ -62,19 +62,12 @@ train_data = rand(Float32, 3, 100)
 deepESN = DeepESN(train_data, 3, 100; depth=3, washout=100)
 ```
 """
-function DeepESN(train_data,
-        in_size::Int,
-        res_size::Int;
-        depth::Int=2,
-        input_layer=fill(scaled_rand, depth),
-        bias=fill(zeros32, depth),
-        reservoir=fill(rand_sparse, depth),
-        reservoir_driver=RNN(),
-        nla_type=NLADefault(),
-        states_type=StandardStates(),
-        washout::Int=0,
-        rng=Utils.default_rng(),
-        matrix_type=typeof(train_data))
+function DeepESN(train_data::AbstractArray, in_size::Int, res_size::Int; depth::Int=2,
+        input_layer=fill(scaled_rand, depth), bias=fill(zeros32, depth),
+        reservoir=fill(rand_sparse, depth), reservoir_driver::AbstractDriver=RNN(),
+        nla_type::NonLinearAlgorithm=NLADefault(),
+        states_type::AbstractStates=StandardStates(), washout::Int=0,
+        rng::AbstractRNG=Utils.default_rng(), matrix_type=typeof(train_data))
     if states_type isa AbstractPaddedStates
         in_size = size(train_data, 1) + 1
        train_data = vcat(adapt(matrix_type, ones(1, size(train_data, 2))),
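At the call site the tightened keywords read as before; only the driver, states type, nonlinear algorithm, and rng now have to be of the expected abstract types. A short sketch (hyperparameters and shapes are illustrative):

```julia
using ReservoirComputing, Random

train_data = rand(Float32, 3, 100)

# per-layer initializers must have length `depth`
deepesn = DeepESN(train_data, 3, 100;
    depth=3,
    reservoir=fill(rand_sparse, 3),
    reservoir_driver=RNN(),              # must be an AbstractDriver
    states_type=StandardStates(),        # must be an AbstractStates
    washout=10,
    rng=Xoshiro(42))                     # must be an AbstractRNG
```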

src/esn/esn.jl

Lines changed: 11 additions & 19 deletions

@@ -12,6 +12,8 @@ struct ESN{I, S, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork
     states::IS
 end
 
+const AbstractDriver = Union{AbstractReservoirDriver, GRU}
+
 """
     ESN(train_data; kwargs...) -> ESN
@@ -52,17 +54,12 @@ julia> esn = ESN(train_data, 10, 300; washout=10)
 ESN(10 => 300)
 ```
 """
-function ESN(train_data,
-        in_size::Int,
-        res_size::Int;
-        input_layer=scaled_rand,
-        reservoir=rand_sparse,
-        bias=zeros32,
-        reservoir_driver=RNN(),
-        nla_type=NLADefault(),
-        states_type=StandardStates(),
-        washout=0,
-        rng=Utils.default_rng(),
+function ESN(train_data::AbstractArray, in_size::Int, res_size::Int;
+        input_layer=scaled_rand, reservoir=rand_sparse, bias=zeros32,
+        reservoir_driver::AbstractDriver=RNN(),
+        nla_type::NonLinearAlgorithm=NLADefault(),
+        states_type::AbstractStates=StandardStates(),
+        washout::Int=0, rng::AbstractRNG=Utils.default_rng(),
         matrix_type=typeof(train_data))
     if states_type isa AbstractPaddedStates
         in_size = size(train_data, 1) + 1
@@ -85,11 +82,9 @@ function ESN(train_data,
 end
 
 function (esn::AbstractEchoStateNetwork)(prediction::AbstractPrediction,
-        output_layer::AbstractOutputLayer;
-        last_state=esn.states[:, [end]],
+        output_layer::AbstractOutputLayer; last_state=esn.states[:, [end]],
         kwargs...)
     pred_len = prediction.prediction_len
-
     return obtain_esn_prediction(esn, prediction, last_state, output_layer;
         kwargs...)
 end
@@ -133,12 +128,9 @@ julia> output_layer = train(esn, rand(Float32, 3, 90))
 OutputLayer successfully trained with output size: 3
 ```
 """
-function train(esn::AbstractEchoStateNetwork,
-        target_data,
-        training_method=StandardRidge();
-        kwargs...)
+function train(esn::AbstractEchoStateNetwork, target_data::AbstractArray,
+        training_method=StandardRidge(); kwargs...)
     states_new = esn.states_type(esn.nla_type, esn.states, esn.train_data[:, 1:end])
-
     return train(training_method, states_new, target_data; kwargs...)
 end
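Taken together, the constructor, the prediction call, and `train` keep the same workflow as before the annotations were added. A hedged end-to-end sketch (data, sizes, and the 30-step horizon are illustrative):

```julia
using ReservoirComputing

data = rand(Float32, 3, 200)
train_data  = data[:, 1:170]
target_data = data[:, 2:171]        # one-step-ahead targets

esn = ESN(train_data, 3, 300; washout=10)
output_layer = train(esn, target_data[:, 11:end])   # drop the washout columns
forecast = esn(Generative(30), output_layer)         # 3×30 autonomous forecast
```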

src/esn/esn_reservoir_drivers.jl

Lines changed: 4 additions & 12 deletions

@@ -22,14 +22,10 @@ specified reservoir driver.
   update.
 """
 function create_states(reservoir_driver::AbstractReservoirDriver,
-        train_data,
-        washout,
-        reservoir_matrix,
-        input_matrix,
-        bias_vector)
+        train_data::AbstractArray, washout::Int, reservoir_matrix::AbstractMatrix,
+        input_matrix::AbstractMatrix, bias_vector::AbstractArray)
     train_len = size(train_data, 2) - washout
     res_size = size(reservoir_matrix, 1)
-
     states = adapt(typeof(train_data), zeros(res_size, train_len))
     tmp_array = allocate_tmp(reservoir_driver, typeof(train_data), res_size)
     _state = adapt(typeof(train_data), zeros(res_size, 1))
@@ -51,14 +47,10 @@ function create_states(reservoir_driver::AbstractReservoirDriver,
 end
 
 function create_states(reservoir_driver::AbstractReservoirDriver,
-        train_data,
-        washout,
-        reservoir_matrix::Vector,
-        input_matrix,
-        bias_vector)
+        train_data::AbstractArray, washout::Int, reservoir_matrix::Vector,
+        input_matrix::AbstractArray, bias_vector::AbstractArray)
     train_len = size(train_data, 2) - washout
     res_size = sum([size(reservoir_matrix[i], 1) for i in 1:length(reservoir_matrix)])
-
     states = adapt(typeof(train_data), zeros(res_size, train_len))
     tmp_array = allocate_tmp(reservoir_driver, typeof(train_data), res_size)
     _state = adapt(typeof(train_data), zeros(res_size))

src/esn/hybridesn.jl

Lines changed: 11 additions & 19 deletions

@@ -12,6 +12,8 @@ struct HybridESN{I, S, V, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork
     states::IS
 end
 
+const AbstractDriver = Union{AbstractReservoirDriver, GRU}
+
 struct KnowledgeModel{T, K, O, I, S, D}
     prior_model::T
     u0::K
@@ -91,19 +93,12 @@ traditional Echo State Networks with a predefined knowledge model [^Pathak2018].
     "Hybrid Forecasting of Chaotic Processes:
     Using Machine Learning in Conjunction with a Knowledge-Based Model" (2018).
 """
-function HybridESN(model,
-        train_data,
-        in_size::Int,
-        res_size::Int;
-        input_layer=scaled_rand,
-        reservoir=rand_sparse,
-        bias=zeros32,
-        reservoir_driver=RNN(),
-        nla_type=NLADefault(),
-        states_type=StandardStates(),
-        washout=0,
-        rng=Utils.default_rng(),
-        T=Float32,
+function HybridESN(model::KnowledgeModel, train_data::AbstractArray,
+        in_size::Int, res_size::Int; input_layer=scaled_rand, reservoir=rand_sparse,
+        bias=zeros32, reservoir_driver::AbstractDriver=RNN(),
+        nla_type::NonLinearAlgorithm=NLADefault(),
+        states_type::AbstractStates=StandardStates(), washout::Int=0,
+        rng::AbstractRNG=Utils.default_rng(), T=Float32,
         matrix_type=typeof(train_data))
     train_data = vcat(train_data, model.model_data[:, 1:(end - 1)])
 
@@ -130,8 +125,7 @@ function HybridESN(model,
 end
 
 function (hesn::HybridESN)(prediction::AbstractPrediction,
-        output_layer::AbstractOutputLayer;
-        last_state=hesn.states[:, [end]],
+        output_layer::AbstractOutputLayer; last_state::AbstractArray=hesn.states[:, [end]],
         kwargs...)
     km = hesn.model
     pred_len = prediction.prediction_len
@@ -148,10 +142,8 @@ function (hesn::HybridESN)(prediction::AbstractPrediction,
         kwargs...)
 end
 
-function train(hesn::HybridESN,
-        target_data,
-        training_method=StandardRidge();
-        kwargs...)
+function train(hesn::HybridESN, target_data::AbstractArray,
+        training_method=StandardRidge(); kwargs...)
     states = vcat(hesn.states, hesn.model.model_data[:, 2:end])
     states_new = hesn.states_type(hesn.nla_type, states, hesn.train_data[:, 1:end])
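As the constructor and `train` show, the hybrid network folds the knowledge model's trajectory into both the inputs (`model.model_data[:, 1:(end - 1)]`) and the collected states (`model.model_data[:, 2:end]`). A rough, hedged construction sketch: it assumes the `KnowledgeModel(prior_model, u0, tspan, datasize)` convenience constructor, where `prior_model(u0, tspan, tsteps)` returns the prior trajectory with one column per time step; the toy prior below is purely illustrative, not the package's API:

```julia
using ReservoirComputing

# toy stand-in for a physics-based prior model (assumption, not the package API)
toy_prior(u0, tspan, tsteps) = reduce(hcat, [u0 .+ sin(t) for t in tsteps])

u0 = Float32[1.0, 0.0, -1.0]
train_data  = rand(Float32, 3, 100)
target_data = rand(Float32, 3, 100)

km = KnowledgeModel(toy_prior, u0, (0.0f0, 99.0f0), 100)   # assumed constructor
hesn = HybridESN(km, train_data, 3, 300; washout=0)
output_layer = train(hesn, target_data)   # states are augmented with km.model_data
```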

src/predict.jl

Lines changed: 6 additions & 13 deletions

@@ -61,16 +61,13 @@ The `Predictive` prediction method uses the provided input data
 (`prediction_data`) to produce corresponding labels or outputs based
 on the learned relationships in the model.
 """
-function Predictive(prediction_data)
+function Predictive(prediction_data::AbstractArray)
     prediction_len = size(prediction_data, 2)
     Predictive(prediction_data, prediction_len)
 end
 
-function obtain_prediction(rc::AbstractReservoirComputer,
-        prediction::Generative,
-        x,
-        output_layer,
-        args...;
+function obtain_prediction(rc::AbstractReservoirComputer, prediction::Generative,
+        x, output_layer::AbstractOutputLayer, args...;
         initial_conditions=output_layer.last_value)
     #x = last_state
     prediction_len = prediction.prediction_len
@@ -88,12 +85,8 @@ function obtain_prediction(rc::AbstractReservoirComputer,
     return output
 end
 
-function obtain_prediction(rc::AbstractReservoirComputer,
-        prediction::Predictive,
-        x,
-        output_layer,
-        args...;
-        kwargs...)
+function obtain_prediction(rc::AbstractReservoirComputer, prediction::Predictive,
+        x, output_layer::AbstractOutputLayer, args...; kwargs...)
     prediction_len = prediction.prediction_len
     train_method = output_layer.training_method
     out_size = output_layer.out_size
@@ -110,7 +103,7 @@ function obtain_prediction(rc::AbstractReservoirComputer,
 end
 
 #linear models
-function get_prediction(training_method, output_layer, x)
+function get_prediction(training_method, output_layer::AbstractOutputLayer, x)
     return output_layer.output_matrix * x
 end
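The two prediction modes dispatched on here differ only in where the next input comes from: `Generative(n)` feeds the model its own output for `n` steps, while `Predictive(data)` reads inputs from `data` and, as the constructor above shows, infers the length from its number of columns. A small illustrative sketch (it assumes an `esn` and `output_layer` produced by an earlier training step, as in the sketch after src/esn/esn.jl above):

```julia
using ReservoirComputing

gen  = Generative(50)                      # autonomous 50-step forecast
pred = Predictive(rand(Float32, 3, 25))    # prediction_len inferred as 25

forecast = esn(gen, output_layer)          # 3×50
labels   = esn(pred, output_layer)         # 3×25
```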

src/train/linear_regression.jl

Lines changed: 73 additions & 4 deletions

@@ -1,3 +1,62 @@
+@doc raw"""
+
+    StandardRidge([Type], [reg])
+
+Returns a training method for `train` based on ridge regression.
+The equations for ridge regression are as follows:
+
+```math
+\mathbf{w} = (\mathbf{X}^\top \mathbf{X} +
+\lambda \mathbf{I})^{-1} \mathbf{X}^\top \mathbf{y}
+```
+
+# Arguments
+- `Type`: type of the regularization argument. Default is inferred internally,
+  there's usually no need to tweak this
+- `reg`: regularization coefficient. Default is set to 0.0 (linear regression).
+
+# Examples
+```jldoctest
+julia> ridge_reg = StandardRidge()
+StandardRidge(0.0)
+
+julia> ol = train(ridge_reg, rand(Float32, 10, 10), rand(Float32, 10, 10))
+OutputLayer successfully trained with output size: 10
+
+julia> ol.output_matrix #visualize output matrix
+10×10 Matrix{Float32}:
+  0.456574   -0.0407612   0.121963   …   0.859327   -0.127494    0.0572494
+  0.133216   -0.0337922   0.0185378      0.24077     0.0297829   0.31512
+  0.379672   -1.24541    -0.444314       1.02269    -0.0446086   0.482282
+  1.18455    -0.517971   -0.133498       0.84473     0.31575     0.205857
+ -0.119345    0.563294    0.747992       0.0102919   1.509      -0.328005
+ -0.0716812   0.0976365   0.628654   …  -0.516041    2.4309     -0.113402
+  0.0153872  -0.52334     0.0526867      0.729326    2.98958     1.32703
+  0.154027    0.6013      1.05548       -0.0840203   0.991182   -0.328555
+  1.11007    -0.0371736  -0.0529418      0.186796   -1.21815     0.204838
+  0.282996   -0.263799    0.132079       0.875417    0.497951    0.273423
+
+julia> ridge_reg = StandardRidge(0.001) #passing a value
+StandardRidge(0.001)
+
+julia> ol = train(ridge_reg, rand(Float16, 10, 10), rand(Float16, 10, 10))
+OutputLayer successfully trained with output size: 10
+
+julia> ol.output_matrix
+10×10 Matrix{Float16}:
+ -1.251     3.074    -1.566      -0.10297  …   0.3823    1.341    -1.77     -0.445
+  0.11017  -2.027     0.8975      0.872       -0.643     0.02615   1.083     0.615
+  0.2634    3.514    -1.168      -1.532        1.486     0.1255   -1.795    -0.06555
+  0.964     0.9463   -0.006855   -0.519        0.0743   -0.181    -0.433     0.06793
+ -0.389     1.887    -0.702      -0.8906       0.221     1.303    -1.318     0.2634
+ -0.1337   -0.4453   -0.06866     0.557    …  -0.322     0.247     0.2554    0.5933
+ -0.6724    0.906    -0.547       0.697       -0.2664    0.809    -0.6836    0.2358
+  0.8843   -3.664     1.615       1.417       -0.6094   -0.59      1.975     0.4785
+  1.266    -0.933     0.0664     -0.4497      -0.0759   -0.03897   1.117     0.3152
+  0.6353    1.327    -0.6978     -1.053        0.8037    0.6577   -0.7246    0.07336
+
+```
+"""
 struct StandardRidge
     reg::Number
 end
@@ -10,13 +69,23 @@ function StandardRidge()
     return StandardRidge(0.0)
 end
 
-function train(sr::StandardRidge,
-        states,
-        target_data)
+function train(sr::StandardRidge, states::AbstractArray, target_data::AbstractArray)
     #A = states * states' + sr.reg * I
     #b = states * target_data
     #output_layer = (A \ b)'
-    output_layer = Matrix(((states * states' + sr.reg * I) \
+
+    if size(states, 2) != size(target_data, 2)
+        throw(DimensionMismatch("\n" *
+                                "\n" *
+                                "  - Number of columns in `states`: $(size(states, 2))\n" *
+                                "  - Number of columns in `target_data`: $(size(target_data, 2))\n" *
+                                "The dimensions of `states` and `target_data` must align for training." *
+                                "\n"
+        ))
+    end
+
+    T = eltype(states)
+    output_layer = Matrix(((states * states' + T(sr.reg) * I) \
         (states * target_data'))')
     return OutputLayer(sr, output_layer, size(target_data, 1), target_data[:, end])
 end
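The rewritten solve simply casts the regularization coefficient to the element type of `states` before applying the docstring's formula. A standalone sketch of the same computation in plain Julia (outside the package, just to make the algebra concrete; shapes and the regularization value are illustrative):

```julia
using LinearAlgebra

# same orientation as the package: states is features × samples,
# target is outputs × samples
states = rand(Float32, 5, 200)
target = rand(Float32, 2, 200)
reg    = 1.0f-3

T = eltype(states)
W = Matrix(((states * states' + T(reg) * I) \ (states * target'))')

size(W)               # (2, 5): outputs × features
W * states[:, end]    # read out a single state column
```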
