Skip to content

Commit 3ac88f1

Browse files
update CellularAutomata
1 parent 4180c63 commit 3ac88f1

File tree

5 files changed

+18
-12
lines changed

5 files changed

+18
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ RCSparseArraysExt = "SparseArrays"
2626
[compat]
2727
Adapt = "4.1.1"
2828
Aqua = "0.8"
29-
CellularAutomata = "0.0.2"
29+
CellularAutomata = "0.0.6"
3030
Compat = "4.16.0"
3131
DifferentialEquations = "7.15.0"
3232
JET = "0.9.18"

src/ReservoirComputing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module ReservoirComputing
22

33
using Adapt: adapt
4-
using CellularAutomata: CellularAutomaton
4+
using CellularAutomata: CellularAutomaton, AbstractCA
55
using Compat: @compat
66
using LinearAlgebra: eigvals, mul!, I, qr, Diagonal
77
using NNlib: fast_act, sigmoid

src/esn/esn_reservoir_drivers.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ specified reservoir driver.
2222
update.
2323
"""
2424
function create_states(reservoir_driver::AbstractReservoirDriver,
25-
train_data::AbstractArray, washout::Int, reservoir_matrix::AbstractMatrix,
26-
input_matrix::AbstractMatrix, bias_vector::AbstractArray)
25+
train_data::AbstractArray{T,2}, washout::Int, reservoir_matrix::AbstractMatrix,
26+
input_matrix::AbstractMatrix, bias_vector::AbstractArray) where {T<:Number}
2727
train_len = size(train_data, 2) - washout
2828
res_size = size(reservoir_matrix, 1)
2929
states = adapt(typeof(train_data), zeros(res_size, train_len))
@@ -32,6 +32,7 @@ function create_states(reservoir_driver::AbstractReservoirDriver,
3232

3333
for i in 1:washout
3434
yv = @view train_data[:, i]
35+
@show typeof(yv)
3536
_state = next_state!(_state, reservoir_driver, _state, yv, reservoir_matrix,
3637
input_matrix, bias_vector, tmp_array)
3738
end
@@ -47,8 +48,8 @@ function create_states(reservoir_driver::AbstractReservoirDriver,
4748
end
4849

4950
function create_states(reservoir_driver::AbstractReservoirDriver,
50-
train_data::AbstractArray, washout::Int, reservoir_matrix::Vector,
51-
input_matrix::AbstractArray, bias_vector::AbstractArray)
51+
train_data::AbstractArray{T,2}, washout::Int, reservoir_matrix::Vector,
52+
input_matrix::AbstractArray, bias_vector::AbstractArray) where {T<:Number}
5253
train_len = size(train_data, 2) - washout
5354
res_size = sum([size(reservoir_matrix[i], 1) for i in 1:length(reservoir_matrix)])
5455
states = adapt(typeof(train_data), zeros(res_size, train_len))
@@ -357,14 +358,19 @@ function obtain_gru_state!(out, variant::FullyGated, gru, x, y, W, W_in, b, tmp_
357358
end
358359

359360
#minimal
361+
#=
360362
function obtain_gru_state!(out, variant::Minimal, gru, x, y, W, W_in, b, tmp_array)
361363
mul!(tmp_array[1], gru.Wz_in, y)
362364
mul!(tmp_array[2], gru.Wz, x)
363365
@. tmp_array[3] = gru.activation_function[1](tmp_array[1] + tmp_array[2] + gru.bz)
364366
367+
mul!(tmp_array[4], gru.Wr_in, y)
368+
mul!(tmp_array[5], gru.Wr, x)
369+
@. tmp_array[6] = gru.activation_function[2](tmp_array[4] + tmp_array[5] + gru.br)
370+
365371
mul!(tmp_array[7], W_in, y)
366372
mul!(tmp_array[8], W, tmp_array[6] .* x)
367-
@. tmp_array[9] = gru.activation_function[2](tmp_array[7] + tmp_array[8] + b)
368-
369-
return @. out = (1 - tmp_array[3]) * x + tmp_array[3] * tmp_array[6]
373+
@. tmp_array[9] = gru.activation_function[3](tmp_array[7] + tmp_array[8] + b)
374+
return @. out = (1 - tmp_array[3]) * x + tmp_array[3] * tmp_array[9]
370375
end
376+
=#

src/reca/reca.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ arXiv preprint arXiv:1410.0162 (2014).
2525
automata._” arXiv preprint arXiv:1703.02806 (2017).
2626
"""
2727
function RECA(train_data,
28-
automata;
29-
generations=8,
28+
automata::AbstractCA;
29+
generations::Int=8,
3030
input_encoding=RandomMapping(),
3131
nla_type=NLADefault(),
3232
states_type=StandardStates())

src/reca/reca_input_encodings.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ struct RandomMaps{T, E, G, M, S} <: AbstractEncodingData
3636
ca_size::S
3737
end
3838

39-
function create_encoding(rm::RandomMapping, input_data, generations)
39+
function create_encoding(rm::RandomMapping, input_data::Int, generations::Int)
4040
maps = init_maps(size(input_data, 1), rm.permutations, rm.expansion_size)
4141
states_size = generations * rm.expansion_size * rm.permutations
4242
ca_size = rm.expansion_size * rm.permutations

0 commit comments

Comments
 (0)