
Commit 4180c63

adding JET and fixes
1 parent 887a441 commit 4180c63

6 files changed: +46 additions, −47 deletions

Project.toml

Lines changed: 3 additions & 1 deletion
@@ -29,6 +29,7 @@ Aqua = "0.8"
 CellularAutomata = "0.0.2"
 Compat = "4.16.0"
 DifferentialEquations = "7.15.0"
+JET = "0.9.18"
 LIBSVM = "0.8"
 LinearAlgebra = "1.10"
 MLJLinearModels = "0.9.2, 0.10"
@@ -46,11 +47,12 @@ julia = "1.10"
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
 LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
+JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["Aqua", "Test", "SafeTestsets", "DifferentialEquations", "MLJLinearModels", "LIBSVM", "Statistics", "SparseArrays"]
+test = ["Aqua", "Test", "SafeTestsets", "DifferentialEquations", "MLJLinearModels", "LIBSVM", "Statistics", "SparseArrays", "JET"]
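For context (not part of the diff): with JET listed under [compat] and added to the test target, the whole suite, including the new JET checks, runs through the standard Pkg test entry point. A minimal sketch, assuming the package is checked out locally and test/runtests.jl includes test/qa.jl:

    using Pkg
    Pkg.activate(".")   # activate the ReservoirComputing.jl checkout
    Pkg.test()          # resolves the [targets] test deps (now including JET) and runs the suite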

src/ReservoirComputing.jl

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@ using CellularAutomata: CellularAutomaton
 using Compat: @compat
 using LinearAlgebra: eigvals, mul!, I, qr, Diagonal
 using NNlib: fast_act, sigmoid
-using Random: Random, AbstractRNG, randperm
+using Random: Random, AbstractRNG, randperm, rand
 using Reexport: Reexport, @reexport
 using WeightInitializers: DeviceAgnostic, PartialFunction, Utils
 @reexport using WeightInitializers
@@ -40,7 +40,7 @@ export scaled_rand, weighted_init, informed_init, minimal_init, chebyshev_mappin
     logistic_mapping, modified_lm
 export rand_sparse, delay_line, delay_line_backward, cycle_jumps,
     simple_cycle, pseudo_svd, chaotic_init
-export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal
+export RNN, MRNN, GRU, GRUParams, FullyGated
 export train
 export ESN, HybridESN, KnowledgeModel, DeepESN
 export RECA

src/esn/esn_inits.jl

Lines changed: 25 additions & 32 deletions
@@ -146,7 +146,7 @@ Create an input layer for informed echo state networks [^Pathak2018].
     Chaos: An Interdisciplinary Journal of Nonlinear Science 28.4 (2018).
 """
 function informed_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
-        scaling=T(0.1), model_in_size, gamma=T(0.5)) where {T <: Number}
+        scaling=T(0.1), model_in_size::Int, gamma=T(0.5)) where {T <: Number}
     res_size, in_size = dims
     state_size = in_size - model_in_size

@@ -162,18 +162,17 @@ function informed_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
     for i in 1:num_for_state
         idxs = findall(Bool[zero_connections .== input_matrix[i, :]
                             for i in 1:size(input_matrix, 1)])
-        random_row_idx = idxs[DeviceAgnostic.rand(rng, T, 1:end)]
-        random_clm_idx = range(1, state_size; step=1)[DeviceAgnostic.rand(rng, T, 1:end)]
+        random_row_idx = idxs[rand(1:length(idxs))]
+        random_clm_idx = rand(state_size+1:in_size)
         input_matrix[random_row_idx, random_clm_idx] = (DeviceAgnostic.rand(rng, T) -
                                                         T(0.5)) .* (T(2) * scaling)
     end

     for i in 1:num_for_model
         idxs = findall(Bool[zero_connections .== input_matrix[i, :]
                             for i in 1:size(input_matrix, 1)])
-        random_row_idx = idxs[DeviceAgnostic.rand(rng, T, 1:end)]
-        random_clm_idx = range(state_size + 1, in_size; step=1)[DeviceAgnostic.rand(
-            rng, T, 1:end)]
+        random_row_idx = idxs[rand(1:length(idxs))]
+        random_clm_idx = rand(state_size+1:in_size)
         input_matrix[random_row_idx, random_clm_idx] = (DeviceAgnostic.rand(rng, T) -
                                                         T(0.5)) .* (T(2) * scaling)
     end
@@ -298,7 +297,7 @@ function irrational(rng::AbstractRNG, ::Type{T}, res_size::Int, in_size::Int;
         end
     end

-    return T.(input_matrix)
+    return map(T, input_matrix)
 end

@@ -888,8 +887,8 @@ end

 """
     pseudo_svd([rng], [T], dims...;
-        max_value=1.0, sparsity=0.1, sorted=true, reverse_sort=false,
-        return_sparse=false)
+        max_value=1.0, sparsity=0.1, sorted=true,
+        return_sparse=false, return_diag=false)

 Returns an initializer to build a sparse reservoir matrix with the given
 `sparsity` by using a pseudo-SVD approach as described in [^yang].
@@ -910,8 +909,6 @@ Returns an initializer to build a sparse reservoir matrix with the given
     Default is 0.1
 - `sorted`: A boolean indicating whether to sort the singular values before
     creating the diagonal matrix. Default is `true`.
-- `reverse_sort`: A boolean indicating whether to reverse the sorted
-    singular values. Default is `false`.
 - `return_sparse`: flag for returning a `sparse` matrix.
     Default is `false`.
 - `return_diag`: flag for returning a `Diagonal` matrix. If both `return_diag`
@@ -936,13 +933,12 @@ julia> res_matrix = pseudo_svd(5, 5)
 """
 function pseudo_svd(rng::AbstractRNG, ::Type{T}, dims::Integer...;
         max_value::Number=T(1.0), sparsity::Number=0.1, sorted::Bool=true,
-        reverse_sort::Bool=false, return_sparse::Bool=false,
+        return_sparse::Bool=false,
         return_diag::Bool=false) where {T <: Number}
     throw_sparse_error(return_sparse)
     reservoir_matrix = create_diag(rng, T, dims[1],
         max_value;
-        sorted=sorted,
-        reverse_sort=reverse_sort)
+        sorted=sorted)
     tmp_sparsity = get_sparsity(reservoir_matrix, dims[1])

     while tmp_sparsity <= sparsity
@@ -960,25 +956,17 @@ function pseudo_svd(rng::AbstractRNG, ::Type{T}, dims::Integer...;
     end
 end

-#hacky workaround for the moment
-function rand_range(rng, T, n::Int)
-    return Int(1 + floor(DeviceAgnostic.rand(rng, T) * n))
-end
-
-function create_diag(rng::AbstractRNG, ::Type{T}, dim::Number, max_value::Number;
-        sorted::Bool=true, reverse_sort::Bool=false) where {T <: Number}
+function create_diag(rng::AbstractRNG, ::Type{T}, dim::Integer, max_value::Number;
+        sorted::Bool=true) where {T <: Number}
     diagonal_matrix = DeviceAgnostic.zeros(rng, T, dim, dim)
-    if sorted == true
-        if reverse_sort == true
-            diagonal_values = sort(
-                DeviceAgnostic.rand(rng, T, dim) .* max_value; rev=true)
-            diagonal_values[1] = max_value
-        else
-            diagonal_values = sort(DeviceAgnostic.rand(rng, T, dim) .* max_value)
-            diagonal_values[end] = max_value
-        end
-    else
-        diagonal_values = DeviceAgnostic.rand(rng, T, dim) .* max_value
+    diagonal_values = Array(DeviceAgnostic.rand(rng, T, dim) .* T(max_value))
+    if sorted
+        #if reverse_sort
+        #    Base.sort!(diagonal_values; rev=true)
+        #    diagonal_values[1] = T(max_value)
+        #else
+        Base.sort!(diagonal_values)
+        diagonal_values[end] = T(max_value)
     end

     for i in 1:dim
@@ -1003,6 +991,11 @@ function create_qmatrix(rng::AbstractRNG, ::Type{T}, dim::Number,
     return qmatrix
 end

+#hacky workaround for the moment
+function rand_range(rng, T, n::Int)
+    return Int(1 + floor(DeviceAgnostic.rand(rng, T) * n))
+end
+
 function get_sparsity(M, dim)
     return size(M[M .!= 0], 1) / (dim * dim - size(M[M .!= 0], 1)) #nonzero/zero elements
 end
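
A usage sketch for the trimmed pseudo_svd keyword surface (not part of the commit; the 5x5 call mirrors the docstring example above, and the keyword values are only illustrative):

    using ReservoirComputing

    res_matrix = pseudo_svd(5, 5)                        # default: sorted singular values, dense output
    diag_matrix = pseudo_svd(5, 5; return_diag=true)     # returns a Diagonal matrix, per the return_diag flag documented above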

src/esn/esn_reservoir_drivers.jl

Lines changed: 9 additions & 9 deletions
@@ -313,13 +313,13 @@ end

 #check this one, not sure
 function create_gru_layers(gru, variant::Minimal, res_size, in_size)
-    Wz_in = gru.inner_layer(res_size, in_size)
-    Wz = gru.reservoir(res_size, res_size)
-    bz = gru.bias(res_size, 1)
+    Wz_in = gru.inner_layer[1](res_size, in_size)
+    Wz = gru.reservoir[1](res_size, res_size)
+    bz = gru.bias[1](res_size, 1)

-    Wr_in = nothing
-    Wr = nothing
-    br = nothing
+    Wr_in = gru.inner_layer[2](res_size, in_size)
+    Wr = gru.reservoir[2](res_size, res_size)
+    br = gru.bias[2](res_size, 1)

     return GRUParams(gru.activation_function, variant, Wz_in, Wz, bz, Wr_in, Wr, br)
 end
@@ -362,9 +362,9 @@ function obtain_gru_state!(out, variant::Minimal, gru, x, y, W, W_in, b, tmp_arr
     mul!(tmp_array[2], gru.Wz, x)
     @. tmp_array[3] = gru.activation_function[1](tmp_array[1] + tmp_array[2] + gru.bz)

-    mul!(tmp_array[4], W_in, y)
-    mul!(tmp_array[5], W, tmp_array[3] .* x)
-    @. tmp_array[6] = gru.activation_function[2](tmp_array[4] + tmp_array[5] + b)
+    mul!(tmp_array[7], W_in, y)
+    mul!(tmp_array[8], W, tmp_array[6] .* x)
+    @. tmp_array[9] = gru.activation_function[2](tmp_array[7] + tmp_array[8] + b)

     return @. out = (1 - tmp_array[3]) * x + tmp_array[3] * tmp_array[6]
 end

src/states.jl

Lines changed: 2 additions & 2 deletions
@@ -199,7 +199,7 @@ function PaddedStates(; padding=1.0)
 end

 function (states_type::PaddedStates)(mat::AbstractMatrix)
-    results = states_type.(eachcol(mat))
+    results = map(states_type, eachcol(mat))
     return hcat(results...)
 end

@@ -294,7 +294,7 @@ nla(nlat::NonLinearAlgorithm, x_old::AbstractVecOrMat) = nlat(x_old)

 # dispatch over matrices for all nonlin algorithms
 function (nlat::NonLinearAlgorithm)(x_old::AbstractMatrix)
-    results = nlat.(eachcol(x_old))
+    results = map(nlat, eachcol(x_old))
     return hcat(results...)
 end
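
A small sketch of the matrix dispatch these two changes touch (not part of the commit; the input matrix is illustrative, and the (5, 10) output shape assumes PaddedStates appends its padding value as one extra row per column):

    using ReservoirComputing

    states = rand(4, 10)                         # 4 state variables over 10 time steps
    padded = PaddedStates(padding=1.0)(states)   # columns mapped one by one, then hcat back together
    size(padded)                                 # (5, 10) under the assumption above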

test/qa.jl

Lines changed: 5 additions & 1 deletion
@@ -1,4 +1,4 @@
-using ReservoirComputing, Aqua
+using ReservoirComputing, Aqua, JET
 @testset "Aqua" begin
     Aqua.find_persistent_tasks_deps(ReservoirComputing)
     Aqua.test_ambiguities(ReservoirComputing; recursive=false)
@@ -9,3 +9,7 @@ using ReservoirComputing, Aqua
     Aqua.test_unbound_args(ReservoirComputing)
     Aqua.test_undefined_exports(ReservoirComputing)
 end
+
+@testset "JET" begin
+    JET.test_package(ReservoirComputing)
+end
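
For local debugging outside the test suite (not part of this commit), JET's report entry point can be called from the REPL as well; a sketch, assuming the JET version pinned above:

    using JET, ReservoirComputing

    JET.report_package(ReservoirComputing)   # print a static-analysis report for the whole package
    JET.test_package(ReservoirComputing)     # same analysis wrapped as a test set, as in test/qa.jl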
