Skip to content

Commit 5d88459

Browse files
adding sparse option to inits
1 parent ae2b601 commit 5d88459

File tree

4 files changed

+18
-13
lines changed

4 files changed

+18
-13
lines changed

.github/workflows/Downgrade.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ jobs:
3333
version: ${{ matrix.version }}
3434
- uses: julia-actions/julia-downgrade-compat@v1
3535
with:
36-
skip: Pkg, TOML, Test, Random, LinearAlgebra, Statistics
36+
skip: Pkg, TOML, Test, Random, LinearAlgebra, Statistics, SparseArrays
3737
- uses: julia-actions/cache@v2
3838
with:
3939
token: ${{ secrets.GITHUB_TOKEN }}

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1212
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1313
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
14+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1415
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1516
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
1617

@@ -35,6 +36,7 @@ NNlib = "0.9.26"
3536
Random = "1.10"
3637
Reexport = "1.2.2"
3738
SafeTestsets = "0.1"
39+
SparseArrays = "1.10"
3840
Statistics = "1.10"
3941
StatsBase = "0.34.4"
4042
Test = "1"

src/ReservoirComputing.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using LinearAlgebra: eigvals, mul!, I
77
using NNlib: fast_act, sigmoid
88
using Random: Random, AbstractRNG
99
using Reexport: Reexport, @reexport
10+
using SparseArrays: sparse
1011
using StatsBase: sample
1112
using WeightInitializers: DeviceAgnostic, PartialFunction, Utils
1213
@reexport using WeightInitializers

src/esn/esn_inits.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -315,15 +315,17 @@ julia> res_matrix = rand_sparse(5, 5; sparsity=0.5)
315315
```
316316
"""
317317
function rand_sparse(rng::AbstractRNG, ::Type{T}, dims::Integer...;
318-
radius=T(1.0), sparsity=T(0.1), std=T(1.0)) where {T <: Number}
318+
radius=T(1.0), sparsity=T(0.1), std=T(1.0),
319+
return_sparse::Bool=false) where {T <: Number}
319320
lcl_sparsity = T(1) - sparsity #consistency with current implementations
320321
reservoir_matrix = sparse_init(rng, T, dims...; sparsity=lcl_sparsity, std=std)
321322
rho_w = maximum(abs.(eigvals(reservoir_matrix)))
322323
reservoir_matrix .*= radius / rho_w
323324
if Inf in unique(reservoir_matrix) || -Inf in unique(reservoir_matrix)
324325
error("Sparsity too low for size of the matrix. Increase res_size or increase sparsity")
325326
end
326-
return reservoir_matrix
327+
328+
return return_sparse ? sparse(reservoir_matrix) : reservoir_matrix
327329
end
328330

329331
"""
@@ -366,7 +368,7 @@ julia> res_matrix = delay_line(5, 5; weight=1)
366368
IEEE transactions on neural networks 22.1 (2010): 131-144.
367369
"""
368370
function delay_line(rng::AbstractRNG, ::Type{T}, dims::Integer...;
369-
weight=T(0.1)) where {T <: Number}
371+
weight=T(0.1), return_sparse::Bool=false) where {T <: Number}
370372
reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
371373
@assert length(dims) == 2&&dims[1] == dims[2] "The dimensions
372374
must define a square matrix (e.g., (100, 100))"
@@ -375,7 +377,7 @@ function delay_line(rng::AbstractRNG, ::Type{T}, dims::Integer...;
375377
reservoir_matrix[i + 1, i] = weight
376378
end
377379

378-
return reservoir_matrix
380+
return return_sparse ? sparse(reservoir_matrix) : reservoir_matrix
379381
end
380382

381383
"""
@@ -421,7 +423,7 @@ julia> res_matrix = delay_line_backward(Float16, 5, 5)
421423
IEEE transactions on neural networks 22.1 (2010): 131-144.
422424
"""
423425
function delay_line_backward(rng::AbstractRNG, ::Type{T}, dims::Integer...;
424-
weight=T(0.1), fb_weight=T(0.2)) where {T <: Number}
426+
weight=T(0.1), fb_weight=T(0.2), return_sparse::Bool=false) where {T <: Number}
425427
res_size = first(dims)
426428
reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
427429

@@ -430,7 +432,7 @@ function delay_line_backward(rng::AbstractRNG, ::Type{T}, dims::Integer...;
430432
reservoir_matrix[i, i + 1] = fb_weight
431433
end
432434

433-
return reservoir_matrix
435+
return return_sparse ? sparse(reservoir_matrix) : reservoir_matrix
434436
end
435437

436438
"""
@@ -479,7 +481,7 @@ julia> res_matrix = cycle_jumps(5, 5; jump_size=2)
479481
"""
480482
function cycle_jumps(rng::AbstractRNG, ::Type{T}, dims::Integer...;
481483
cycle_weight::Number=T(0.1), jump_weight::Number=T(0.1),
482-
jump_size::Int=3) where {T <: Number}
484+
jump_size::Int=3, return_sparse::Bool=false) where {T <: Number}
483485
res_size = first(dims)
484486
reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
485487

@@ -498,7 +500,7 @@ function cycle_jumps(rng::AbstractRNG, ::Type{T}, dims::Integer...;
498500
reservoir_matrix[tmp, i] = jump_weight
499501
end
500502

501-
return reservoir_matrix
503+
return return_sparse ? sparse(reservoir_matrix) : reservoir_matrix
502504
end
503505

504506
"""
@@ -540,15 +542,15 @@ julia> res_matrix = simple_cycle(5, 5; weight=11)
540542
IEEE transactions on neural networks 22.1 (2010): 131-144.
541543
"""
542544
function simple_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
543-
weight=T(0.1)) where {T <: Number}
545+
weight=T(0.1), return_sparse::Bool=false) where {T <: Number}
544546
reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
545547

546548
for i in 1:(dims[1] - 1)
547549
reservoir_matrix[i + 1, i] = weight
548550
end
549551

550552
reservoir_matrix[1, dims[1]] = weight
551-
return reservoir_matrix
553+
return return_sparse ? sparse(reservoir_matrix) : reservoir_matrix
552554
end
553555

554556
"""
@@ -590,7 +592,7 @@ julia> res_matrix = pseudo_svd(5, 5)
590592
"""
591593
function pseudo_svd(rng::AbstractRNG, ::Type{T}, dims::Integer...;
592594
max_value::Number=T(1.0), sparsity::Number=0.1, sorted::Bool=true,
593-
reverse_sort::Bool=false) where {T <: Number}
595+
reverse_sort::Bool=false, return_sparse::Bool=false) where {T <: Number}
594596
reservoir_matrix = create_diag(rng, T, dims[1],
595597
max_value;
596598
sorted=sorted,
@@ -605,7 +607,7 @@ function pseudo_svd(rng::AbstractRNG, ::Type{T}, dims::Integer...;
605607
tmp_sparsity = get_sparsity(reservoir_matrix, dims[1])
606608
end
607609

608-
return reservoir_matrix
610+
return return_sparse ? sparse(reservoir_matrix) : reservoir_matrix
609611
end
610612

611613
#hacky workaround for the moment

0 commit comments

Comments (0)