Skip to content

Commit ccde5df

Browse files
chore: split reca models in ext
1 parent 2bbb28d commit ccde5df

File tree

15 files changed

+118
-114
lines changed

15 files changed

+118
-114
lines changed

Project.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
name = "ReservoirComputing"
22
uuid = "7c2d2b1e-3dd4-11ea-355a-8f6a8116e294"
33
authors = ["Francesco Martinuzzi"]
4-
version = "0.11.3"
4+
version = "0.11.4"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8-
CellularAutomata = "878138dc-5b27-11ea-1a71-cb95d38d6b29"
98
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
109
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1110
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
@@ -14,19 +13,21 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1413
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
1514

1615
[weakdeps]
16+
CellularAutomata = "878138dc-5b27-11ea-1a71-cb95d38d6b29"
1717
LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
1818
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
1919
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2020

2121
[extensions]
22+
RCCellularAutomataExt = "CellularAutomata"
2223
RCLIBSVMExt = "LIBSVM"
2324
RCMLJLinearModelsExt = "MLJLinearModels"
2425
RCSparseArraysExt = "SparseArrays"
2526

2627
[compat]
2728
Adapt = "4.1.1"
2829
Aqua = "0.8"
29-
CellularAutomata = "0.0.2"
30+
CellularAutomata = "0.0.6"
3031
Compat = "4.16.0"
3132
DifferentialEquations = "7.16.1"
3233
LIBSVM = "0.8"
@@ -53,4 +54,4 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
5354
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5455

5556
[targets]
56-
test = ["Aqua", "Test", "SafeTestsets", "DifferentialEquations", "MLJLinearModels", "LIBSVM", "Statistics", "SparseArrays"]
57+
test = ["Aqua", "Test", "SafeTestsets", "DifferentialEquations", "MLJLinearModels", "LIBSVM", "Statistics", "SparseArrays", "CellularAutomata"]

src/reca/reca_input_encodings.jl renamed to ext/RCCellularAutomataExt.jl

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,49 @@
1-
abstract type AbstractInputEncoding end
2-
abstract type AbstractEncodingData end
1+
module RCCellularAutomataExt
2+
using ReservoirComputing: RECA, RandomMapping, RandomMaps
3+
import ReservoirComputing: train, next_state_prediction!, AbstractOutputLayer, NLADefault,
4+
StandardStates, obtain_prediction
5+
using CellularAutomata
6+
using Random: randperm
7+
8+
function RECA(train_data,
9+
automata;
10+
generations = 8,
11+
input_encoding = RandomMapping(),
12+
nla_type = NLADefault(),
13+
states_type = StandardStates())
14+
in_size = size(train_data, 1)
15+
#res_size = obtain_res_size(input_encoding, generations)
16+
state_encoding = create_encoding(input_encoding, train_data, generations)
17+
states = reca_create_states(state_encoding, automata, train_data)
18+
19+
return RECA(train_data, automata, state_encoding, nla_type, states, states_type)
20+
end
321

4-
struct RandomMapping{I, T} <: AbstractInputEncoding
5-
permutations::I
6-
expansion_size::T
22+
#training dispatch
23+
function train(reca::RECA, target_data, training_method = StandardRidge; kwargs...)
24+
states_new = reca.states_type(reca.nla_type, reca.states, reca.train_data)
25+
return train(training_method, Float32.(states_new), Float32.(target_data); kwargs...)
726
end
827

9-
"""
10-
RandomMapping(permutations, expansion_size)
11-
RandomMapping(permutations; expansion_size=40)
12-
RandomMapping(;permutations=8, expansion_size=40)
28+
#predict dispatch
29+
function (reca::RECA)(prediction,
30+
output_layer::AbstractOutputLayer,
31+
initial_conditions = output_layer.last_value,
32+
last_state = zeros(reca.input_encoding.ca_size))
33+
return obtain_prediction(reca, prediction, last_state, output_layer;
34+
initial_conditions = initial_conditions)
35+
end
1336

14-
Random mapping of the input data directly in the reservoir. The `expansion_size`
15-
determines the dimension of the single reservoir, and `permutations` determines the
16-
number of total reservoirs that will be connected, each with a different mapping.
17-
The detail of this implementation can be found in [1].
37+
function next_state_prediction!(reca::RECA, x, out, i, args...)
38+
rm = reca.input_encoding
39+
x = encoding(rm, out, x)
40+
ca = CellularAutomaton(reca.automata, x, rm.generations + 1)
41+
ca_states = ca.evolution[2:end, :]
42+
x_new = reshape(transpose(ca_states), rm.states_size)
43+
x = ca.evolution[end, :]
44+
return x, x_new
45+
end
1846

19-
[1] Nichele, Stefano, and Andreas Molund. “Deep reservoir computing using cellular
20-
automata.” arXiv preprint arXiv:1703.02806 (2017).
21-
"""
2247
function RandomMapping(; permutations = 8, expansion_size = 40)
2348
RandomMapping(permutations, expansion_size)
2449
end
@@ -27,15 +52,6 @@ function RandomMapping(permutations; expansion_size = 40)
2752
RandomMapping(permutations, expansion_size)
2853
end
2954

30-
struct RandomMaps{T, E, G, M, S} <: AbstractEncodingData
31-
permutations::T
32-
expansion_size::E
33-
generations::G
34-
maps::M
35-
states_size::S
36-
ca_size::S
37-
end
38-
3955
function create_encoding(rm::RandomMapping, input_data, generations)
4056
maps = init_maps(size(input_data, 1), rm.permutations, rm.expansion_size)
4157
states_size = generations * rm.expansion_size * rm.permutations
@@ -70,7 +86,7 @@ function encoding(rm::RandomMaps, input_vector, tot_encoded_vector)
7086
input_vector,
7187
new_tot_enc_vec[((i - 1) * rm.expansion_size + 1):(i * rm.expansion_size)],
7288
rm.maps[i,
73-
:])
89+
:])
7490
end
7591

7692
return new_tot_enc_vec
@@ -105,3 +121,5 @@ function mapping(input_size, mapped_vector_size)
105121
#sample(1:mapped_vector_size, input_size; replace=false)
106122
return randperm(mapped_vector_size)[1:input_size]
107123
end
124+
125+
end #module

src/ReservoirComputing.jl

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module ReservoirComputing
22

33
using Adapt: adapt
4-
using CellularAutomata: CellularAutomaton
54
using Compat: @compat
65
using LinearAlgebra: eigvals, mul!, I, qr, Diagonal
76
using NNlib: fast_act, sigmoid
@@ -15,24 +14,19 @@ abstract type AbstractReservoirComputer end
1514
@compat(public, (create_states))
1615

1716
#general
18-
include("states.jl")
19-
include("predict.jl")
20-
21-
#general training
22-
include("train/linear_regression.jl")
23-
17+
include("generics/states.jl")
18+
include("generics/predict.jl")
19+
include("generics/linear_regression.jl")
20+
#extensions
21+
include("extensions/reca.jl")
2422
#esn
25-
include("esn/inits_components.jl")
26-
include("esn/esn_inits.jl")
27-
include("esn/esn_reservoir_drivers.jl")
28-
include("esn/esn.jl")
29-
include("esn/deepesn.jl")
30-
include("esn/hybridesn.jl")
31-
include("esn/esn_predict.jl")
32-
33-
#reca
34-
include("reca/reca.jl")
35-
include("reca/reca_input_encodings.jl")
23+
include("inits/inits_components.jl")
24+
include("inits/esn_inits.jl")
25+
include("layers/esn_reservoir_drivers.jl")
26+
include("models/esn.jl")
27+
include("models/deepesn.jl")
28+
include("models/hybridesn.jl")
29+
include("models/esn_predict.jl")
3630

3731
export NLADefault, NLAT1, NLAT2, NLAT3, PartialSquare, ExtendedSquare
3832
export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates
@@ -48,8 +42,9 @@ export add_jumps!, backward_connection!, delay_line!, reverse_simple_cycle!,
4842
export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal
4943
export train
5044
export ESN, HybridESN, KnowledgeModel, DeepESN
45+
export Generative, Predictive, OutputLayer
46+
#reca
5147
export RECA
5248
export RandomMapping, RandomMaps
53-
export Generative, Predictive, OutputLayer
5449

5550
end #module

src/extensions/reca.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
abstract type AbstractInputEncoding end
2+
abstract type AbstractEncodingData end
3+
4+
"""
5+
RandomMapping(permutations, expansion_size)
6+
RandomMapping(permutations; expansion_size=40)
7+
RandomMapping(;permutations=8, expansion_size=40)
8+
9+
Random mapping of the input data directly in the reservoir. The `expansion_size`
10+
determines the dimension of the single reservoir, and `permutations` determines the
11+
number of total reservoirs that will be connected, each with a different mapping.
12+
The detail of this implementation can be found in [1].
13+
14+
[1] Nichele, Stefano, and Andreas Molund. “Deep reservoir computing using cellular
15+
automata.” arXiv preprint arXiv:1703.02806 (2017).
16+
"""
17+
struct RandomMapping{I, T} <: AbstractInputEncoding
18+
permutations::I
19+
expansion_size::T
20+
end
21+
22+
struct RandomMaps{T, E, G, M, S} <: AbstractEncodingData
23+
permutations::T
24+
expansion_size::E
25+
generations::G
26+
maps::M
27+
states_size::S
28+
ca_size::S
29+
end
30+
31+
abstract type AbstractReca <: AbstractReservoirComputer end
32+
33+
"""
34+
RECA(train_data,
35+
automata;
36+
generations = 8,
37+
input_encoding=RandomMapping(),
38+
nla_type = NLADefault(),
39+
states_type = StandardStates())
40+
41+
[1] Yilmaz, Ozgur. “_Reservoir computing using cellular automata._”
42+
arXiv preprint arXiv:1410.0162 (2014).
43+
44+
[2] Nichele, Stefano, and Andreas Molund. “_Deep reservoir computing using cellular
45+
automata._” arXiv preprint arXiv:1703.02806 (2017).
46+
"""
47+
struct RECA{S, R, E, T, Q} <: AbstractReca
48+
#res_size::I
49+
train_data::S
50+
automata::R
51+
input_encoding::E
52+
nla_type::ReservoirComputing.NonLinearAlgorithm
53+
states::T
54+
states_type::Q
55+
end
File renamed without changes.
File renamed without changes.
File renamed without changes.

src/esn/esn_inits.jl renamed to src/inits/esn_inits.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -944,7 +944,6 @@ function digital_chaotic_adjacency(rng::AbstractRNG, bit_precision::Integer;
944944
end
945945
adjacency_matrix[matrix_order, 1] = 1
946946
for row_index in 1:matrix_order, column_index in 1:matrix_order
947-
948947
if row_index != column_index && rand(rng) < extra_edge_probability
949948
adjacency_matrix[row_index, column_index] = 1
950949
end
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)