
Commit 9e4711f

Merge pull request #239 from SciML/fm/ql: General codebase improvements

2 parents 6b828a4 + 113ba19

11 files changed (+926, -857 lines)

Project.toml

Lines changed: 4 additions & 7 deletions

@@ -6,13 +6,10 @@ version = "0.10.5"
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CellularAutomata = "878138dc-5b27-11ea-1a71-cb95d38d6b29"
-Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
-Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

@@ -29,19 +26,17 @@ Adapt = "4.1.1"
 Aqua = "0.8"
 CellularAutomata = "0.0.2"
 DifferentialEquations = "7.15.0"
-Distances = "0.10"
 LIBSVM = "0.8"
 LinearAlgebra = "1.10"
 MLJLinearModels = "0.9.2, 0.10"
 NNlib = "0.9.26"
-PartialFunctions = "1.2"
 Random = "1.10"
 Reexport = "1.2.2"
 SafeTestsets = "0.1"
 Statistics = "1.10"
 StatsBase = "0.34.4"
 Test = "1"
-WeightInitializers = "1.0.4"
+WeightInitializers = "1.0.5"
 julia = "1.10"

 [extras]

@@ -51,7 +46,9 @@ LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
 MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["Aqua", "Test", "SafeTestsets", "Random", "DifferentialEquations", "MLJLinearModels", "LIBSVM"]
+test = ["Aqua", "Test", "SafeTestsets", "Random", "DifferentialEquations",
+    "MLJLinearModels", "LIBSVM", "Statistics"]

src/ReservoirComputing.jl

Lines changed: 7 additions & 120 deletions

@@ -1,126 +1,16 @@
 module ReservoirComputing

-using Adapt
-using CellularAutomata
-using Distances
-using LinearAlgebra
-using NNlib
-using PartialFunctions
-using Random
+using Adapt: adapt
+using CellularAutomata: CellularAutomaton
+using LinearAlgebra: eigvals, mul!, I
+using NNlib: fast_act, sigmoid
+using Random: Random, AbstractRNG
 using Reexport: Reexport, @reexport
-using Statistics
 using StatsBase: sample
 using WeightInitializers: DeviceAgnostic, PartialFunction, Utils
 @reexport using WeightInitializers

-#define global types
 abstract type AbstractReservoirComputer end
-abstract type AbstractOutputLayer end
-abstract type AbstractPrediction end
-#should probably move some of these
-abstract type AbstractGRUVariant end
-
-#general output layer struct
-struct OutputLayer{T, I, S, L} <: AbstractOutputLayer
-    training_method::T
-    output_matrix::I
-    out_size::S
-    last_value::L
-end
-
-#prediction types
-"""
-    Generative(prediction_len)
-
-A prediction strategy that enables models to generate autonomous multi-step
-forecasts by recursively feeding their own outputs back as inputs for
-subsequent prediction steps.
-
-# Parameters
-
-  - `prediction_len::Int`: The number of future steps to predict.
-
-# Description
-
-The `Generative` prediction method allows a model to perform multi-step
-forecasting by using its own previous predictions as inputs for future predictions.
-This approach is especially useful in time series analysis, where each prediction
-depends on the preceding data points.
-
-At each step, the model takes the current input, generates a prediction,
-and then incorporates that prediction into the input for the next step.
-This recursive process continues until the specified
-number of prediction steps (`prediction_len`) is reached.
-"""
-struct Generative{T} <: AbstractPrediction
-    prediction_len::T
-end
-
-struct Predictive{I, T} <: AbstractPrediction
-    prediction_data::I
-    prediction_len::T
-end
-
-"""
-    Predictive(prediction_data)
-
-A prediction strategy for supervised learning tasks,
-where a model predicts labels based on a provided set
-of input features (`prediction_data`).
-
-# Parameters
-
-  - `prediction_data`: The input data used for prediction, typically structured as a matrix
-    where each column represents a sample, and each row represents a feature.
-
-# Description
-
-The `Predictive` prediction method is a standard approach
-in supervised machine learning tasks. It uses the provided input data
-(`prediction_data`) to produce corresponding labels or outputs based
-on the learned relationships in the model. Unlike generative prediction,
-this method does not recursively feed predictions into the model;
-instead, it operates on fixed input data to produce a single batch of predictions.
-
-This method is suitable for tasks like classification,
-regression, or other use cases where the input features
-and the number of steps are predefined.
-"""
-function Predictive(prediction_data)
-    prediction_len = size(prediction_data, 2)
-    Predictive(prediction_data, prediction_len)
-end
-
-#fallbacks for initializers #eventually to remove once migrated to WeightInitializers.jl
-for initializer in (:rand_sparse, :delay_line, :delay_line_backward, :cycle_jumps,
-    :simple_cycle, :pseudo_svd,
-    :scaled_rand, :weighted_init, :informed_init, :minimal_init)
-    @eval begin
-        function ($initializer)(dims::Integer...; kwargs...)
-            return $initializer(Utils.default_rng(), Float32, dims...; kwargs...)
-        end
-        function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...)
-            return $initializer(rng, Float32, dims...; kwargs...)
-        end
-        function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T <: Number}
-            return $initializer(Utils.default_rng(), T, dims...; kwargs...)
-        end
-
-        # Partial application
-        function ($initializer)(rng::AbstractRNG; kwargs...)
-            return PartialFunction.Partial{Nothing}($initializer, rng, kwargs)
-        end
-        function ($initializer)(::Type{T}; kwargs...) where {T <: Number}
-            return PartialFunction.Partial{T}($initializer, nothing, kwargs)
-        end
-        function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: Number}
-            return PartialFunction.Partial{T}($initializer, rng, kwargs)
-        end
-        function ($initializer)(; kwargs...)
-            return PartialFunction.Partial{Nothing}($initializer, nothing, kwargs)
-        end
-    end
-end

 #general
 include("states.jl")

@@ -130,8 +20,7 @@ include("predict.jl")
 include("train/linear_regression.jl")

 #esn
-include("esn/esn_input_layers.jl")
-include("esn/esn_reservoirs.jl")
+include("esn/esn_inits.jl")
 include("esn/esn_reservoir_drivers.jl")
 include("esn/esn.jl")
 include("esn/deepesn.jl")

@@ -155,9 +44,7 @@ export scaled_rand, weighted_init, informed_init, minimal_init
 export rand_sparse, delay_line, delay_line_backward, cycle_jumps, simple_cycle, pseudo_svd
 export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal
 export train
-export ESN
-export HybridESN, KnowledgeModel
-export DeepESN
+export ESN, HybridESN, KnowledgeModel, DeepESN
 export RECA, sample
 export RandomMapping, RandomMaps
 export Generative, Predictive, OutputLayer
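Two changes in this file: blanket `using` statements become explicit imports (e.g. `using Adapt: adapt`), and the local fallback/partial-application methods for the initializers are deleted, since `WeightInitializers.jl` now supplies that machinery. The call patterns the removed loop supported should remain available through the reexport; a minimal sketch, assuming `rand_sparse` keeps keyword arguments such as `radius` and `sparsity`:

```julia
using Random: Xoshiro
using ReservoirComputing

# Explicit call: rng, element type, and dimensions all given.
W1 = rand_sparse(Xoshiro(42), Float32, 100, 100)

# Convenience call: rng and eltype fall back to the package defaults
# (Utils.default_rng() and Float32).
W2 = rand_sparse(100, 100)

# Partial application: fix keywords now, pass dimensions later,
# e.g. when handing the initializer to an ESN constructor.
init = rand_sparse(; radius=1.2, sparsity=0.1)
W3 = init(100, 100)
```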

src/esn/deepesn.jl

Lines changed: 9 additions & 14 deletions

@@ -44,27 +44,22 @@ temporal features.
     Default is an RNN model.
   - `nla_type`: The type of non-linear activation used in the reservoir.
     Default is `NLADefault()`.
-  - `states_type`: Defines the type of states used in the ESN (e.g., standard states).
-    Default is `StandardStates()`.
-  - `washout`: The number of initial timesteps to be discarded in the ESN's training phase.
-    Default is 0.
-  - `rng`: Random number generator used for initializing weights. Default is the package's
-    default random number generator.
+  - `states_type`: Defines the type of states used in the ESN
+    (e.g., standard states). Default is `StandardStates()`.
+  - `washout`: The number of initial timesteps to be discarded
+    in the ESN's training phase. Default is 0.
+  - `rng`: Random number generator used for initializing weights.
+    Default is `Utils.default_rng()`.
   - `matrix_type`: The type of matrix used for storing the training data.
     Default is inferred from `train_data`.

 # Example

 ```julia
-# Prepare your training data
-train_data = [your_training_data_here]
+train_data = rand(Float32, 3, 100)

 # Create a DeepESN with specific parameters
-deepESN = DeepESN(train_data, 10, 100; depth=3, washout=100)
-
-# Proceed with training and prediction (pseudocode)
-train(deepESN, target_data)
-prediction = predict(deepESN, new_data)
+deepESN = DeepESN(train_data, 3, 100; depth=3, washout=100)
 ```
 """
 function DeepESN(train_data,

@@ -82,7 +77,7 @@ function DeepESN(train_data,
         matrix_type=typeof(train_data))
     if states_type isa AbstractPaddedStates
         in_size = size(train_data, 1) + 1
-        train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))),
+        train_data = vcat(adapt(matrix_type, ones(1, size(train_data, 2))),
             train_data)
     end
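The only functional change in the constructor is the call site: with `using Adapt: adapt` at module level, the padded-states branch drops the `Adapt.` qualifier. What that branch does, as a standalone sketch (variable names taken from the diff, the data is illustrative): when a padded states type is requested, a constant row of ones is prepended to the training data, so the effective input size grows by one.

```julia
using Adapt: adapt

# Sketch of the padded-states branch; behavior is unchanged by this commit.
train_data = rand(Float32, 3, 100)
matrix_type = typeof(train_data)

in_size = size(train_data, 1) + 1  # 4 after padding
# adapt converts the ones row to the same matrix type as train_data
# before concatenation, so the eltype stays Float32.
train_data = vcat(adapt(matrix_type, ones(1, size(train_data, 2))),
    train_data)
size(train_data)  # (4, 100)
```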

src/esn/esn.jl

Lines changed: 47 additions & 35 deletions

@@ -15,33 +15,41 @@ end
 """
     ESN(train_data; kwargs...) -> ESN

-Creates an Echo State Network (ESN) using specified parameters and training data, suitable for various machine learning tasks.
+Creates an Echo State Network (ESN).

-# Parameters
+# Arguments

-  - `train_data`: Matrix of training data (columns as time steps, rows as features).
+  - `train_data`: Matrix of training data `num_features x time_steps`.
   - `variation`: Variation of ESN (default: `Default()`).
   - `input_layer`: Input layer of ESN.
   - `reservoir`: Reservoir of the ESN.
   - `bias`: Bias vector for each time step.
+  - `rng`: Random number generator used for initializing weights.
+    Default is `Utils.default_rng()`.
   - `reservoir_driver`: Mechanism for evolving reservoir states (default: `RNN()`).
   - `nla_type`: Non-linear activation type (default: `NLADefault()`).
   - `states_type`: Format for storing states (default: `StandardStates()`).
   - `washout`: Initial time steps to discard (default: `0`).
   - `matrix_type`: Type of matrices used internally (default: type of `train_data`).

-# Returns
-
-  - An initialized ESN instance with specified parameters.
-
 # Examples

-```julia
-using ReservoirComputing
-
-train_data = rand(10, 100)  # 10 features, 100 time steps
-
-esn = ESN(train_data; reservoir=RandSparseReservoir(200), washout=10)
+```jldoctest
+julia> train_data = rand(Float32, 10, 100)  # 10 features, 100 time steps
+10×100 Matrix{Float32}:
+ 0.567676   0.154756  0.584611  0.294015   …  0.573946    0.894333    0.429133
+ 0.327073   0.729521  0.804667  0.263944      0.559342    0.020167    0.897862
+ 0.453606   0.800058  0.568311  0.749441      0.0713146   0.464795    0.532854
+ 0.0173253  0.536959  0.722116  0.910328      0.00224048  0.00202501  0.631075
+ 0.366744   0.119761  0.100593  0.125122      0.700562    0.675474    0.102947
+ 0.539737   0.768351  0.54681   0.648672   …  0.256738    0.223784    0.94327
+ 0.558099   0.42676   0.1948    0.735625      0.0989234   0.119342    0.624182
+ 0.0603135  0.929999  0.263439  0.0372732     0.066125    0.332769    0.25562
+ 0.4463     0.334423  0.444679  0.311695      0.0494497   0.27171     0.214925
+ 0.987182   0.898593  0.295241  0.233098      0.789699    0.453692    0.759205
+
+julia> esn = ESN(train_data, 10, 300; washout=10)
+ESN(10 => 300)
 ```
 """
 function ESN(train_data,

@@ -58,7 +66,7 @@ function ESN(train_data,
         matrix_type=typeof(train_data))
     if states_type isa AbstractPaddedStates
         in_size = size(train_data, 1) + 1
-        train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))),
+        train_data = vcat(adapt(matrix_type, ones(1, size(train_data, 2))),
            train_data)
     end

@@ -86,6 +94,10 @@ function (esn::AbstractEchoStateNetwork)(prediction::AbstractPrediction,
         kwargs...)
 end

+function Base.show(io::IO, esn::ESN)
+    print(io, "ESN(", size(esn.train_data, 1), " => ", size(esn.reservoir_matrix, 1), ")")
+end
+
 #training dispatch on esn
 """
     train(esn::AbstractEchoStateNetwork, target_data, training_method = StandardRidge(0.0))

@@ -98,27 +110,27 @@ Trains an Echo State Network (ESN) using the provided target data and a specified
   - `target_data`: Supervised training data for the ESN.
   - `training_method`: The method for training the ESN (default: `StandardRidge(0.0)`).

-# Returns
-
-  - The trained ESN model. Its type and structure depend on `training_method` and the ESN's implementation.
-
-# Returns
-
-The trained ESN model. The exact type and structure of the return value depends on the
-`training_method` and the specific ESN implementation.
-
-```julia
-using ReservoirComputing
-
-# Initialize an ESN instance and target data
-esn = ESN(train_data; reservoir=RandSparseReservoir(200), washout=10)
-target_data = rand(size(train_data, 2))
-
-# Train the ESN using the default training method
-trained_esn = train(esn, target_data)
-
-# Train the ESN using a custom training method
-trained_esn = train(esn, target_data; training_method=StandardRidge(1.0))
+# Example
+
+```jldoctest
+julia> train_data = rand(Float32, 10, 100)  # 10 features, 100 time steps
+10×100 Matrix{Float32}:
+ 0.11437   0.425367  0.585867   0.34078   …  0.0531493  0.761425  0.883164
+ 0.301373  0.497806  0.279603   0.802417     0.49873    0.270156  0.333333
+ 0.135224  0.660179  0.394233   0.512753     0.901221   0.784377  0.687691
+ 0.510203  0.877234  0.614245   0.978405     0.332775   0.768826  0.527077
+ 0.955027  0.398322  0.312156   0.981938     0.473357   0.156704  0.476101
+ 0.353024  0.997632  0.164328   0.470783  …  0.745613   0.85797   0.465201
+ 0.966044  0.194299  0.599167   0.040475     0.0996013  0.325959  0.770103
+ 0.292068  0.495138  0.481299   0.214566     0.819573   0.155951  0.227168
+ 0.133498  0.451058  0.0761995  0.90421      0.994212   0.332164  0.545112
+ 0.214467  0.791524  0.124105   0.951805     0.947166   0.954244  0.889733
+
+julia> esn = ESN(train_data, 10, 300; washout=10)
+ESN(10 => 300)
+
+julia> output_layer = train(esn, rand(Float32, 3, 90))
+OutputLayer successfully trained with output size: 3
 ```
 """
 function train(esn::AbstractEchoStateNetwork,
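With the new `Base.show` method, an `ESN` now prints compactly as `ESN(in_size => reservoir_size)`, which is what the doctests above rely on. Extending that workflow through prediction, here is a minimal sketch; the sizes and prediction length are illustrative, and `Generative` needs the readout dimension to match the input dimension so outputs can be fed back, which is why this sketch uses a 10-feature target rather than the 3-feature target from the doctest:

```julia
using ReservoirComputing

train_data = rand(Float32, 10, 100)         # 10 features, 100 time steps
esn = ESN(train_data, 10, 300; washout=10)  # prints as ESN(10 => 300)

# Readout targets aligned with the 90 post-washout states; same feature
# count as the input so generative feedback is well-defined.
target_data = rand(Float32, 10, 90)
output_layer = train(esn, target_data)

# Autonomous 50-step forecast: each output is fed back as the next input.
prediction = esn(Generative(50), output_layer)
size(prediction)  # (10, 50)
```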
