Commit e011c24

generalized minimal building blocks
1 parent a268310 commit e011c24

4 files changed: +175 additions, -145 deletions

src/ReservoirComputing.jl

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ include("predict.jl")
 include("train/linear_regression.jl")
 
 #esn
+include("esn/inits_components.jl")
 include("esn/esn_inits.jl")
 include("esn/esn_reservoir_drivers.jl")
 include("esn/esn.jl")

src/esn/esn_inits.jl

Lines changed: 32 additions & 143 deletions
@@ -1,35 +1,3 @@
-# dispatch over dense inits
-function return_init_as(::Val{false}, layer_matrix::AbstractVecOrMat)
-    return layer_matrix
-end
-
-# error for sparse inits with no SparseArrays.jl call
-
-function throw_sparse_error(return_sparse::Bool)
-    if return_sparse && !haskey(Base.loaded_modules, :SparseArrays)
-        error("""\n
-            Sparse output requested but SparseArrays.jl is not loaded.
-            Please load it with:
-
-                using SparseArrays\n
-            """)
-    end
-end
-
-## scale spectral radius
-
-function scale_radius!(reservoir_matrix::AbstractMatrix, radius::AbstractFloat)
-    rho_w = maximum(abs.(eigvals(reservoir_matrix)))
-    reservoir_matrix .*= radius / rho_w
-    if Inf in unique(reservoir_matrix) || -Inf in unique(reservoir_matrix)
-        error("""\n
-            Sparsity too low for size of the matrix.
-            Increase res_size or increase sparsity.\n
-            """)
-    end
-    return reservoir_matrix
-end
-
 ### input layers
 """
     scaled_rand([rng], [T], dims...;
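
Context for the removal: `return_init_as`, `throw_sparse_error`, and `scale_radius!` are still called by the initializers below, so they are evidently relocated rather than deleted, presumably into the new esn/inits_components.jl added to the includes above. Caller-facing behavior should not change; sparse output still requires SparseArrays.jl to be loaded first. A usage sketch (the `sparsity` keyword is taken from the existing `rand_sparse` docstring, not from this diff):

    using ReservoirComputing
    using SparseArrays  # required for return_sparse=true, per throw_sparse_error

    # 100x100 reservoir returned in sparse format
    res_matrix = rand_sparse(100, 100; sparsity=0.1, return_sparse=true)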
@@ -274,51 +242,16 @@ julia> res_input = minimal_init(8, 3; p=0.8) # higher p -> more positive signs
     IEEE transactions on neural networks 22.1 (2010): 131-144.
 """
 function minimal_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
-        sampling_type::Symbol=:bernoulli, kwargs...) where {T <: Number}
+        weight::Number=T(0.1), sampling_type::Symbol=:bernoulli!, kwargs...) where {T <:
+        Number}
     res_size, in_size = dims
-    f_sample = getfield(@__MODULE__, sampling_type)
-    layer_matrix = f_sample(rng, T, res_size, in_size; kwargs...)
-    return layer_matrix
-end
-
-function bernoulli(rng::AbstractRNG, ::Type{T}, res_size::Int, in_size::Int;
-        weight::Number=T(0.1), p::Number=T(0.5)) where {T <: Number}
     input_matrix = DeviceAgnostic.zeros(rng, T, res_size, in_size)
-    for idx in 1:res_size
-        for jdx in 1:in_size
-            if DeviceAgnostic.rand(rng, T) < p
-                input_matrix[idx, jdx] = T(weight)
-            else
-                input_matrix[idx, jdx] = -T(weight)
-            end
-        end
-    end
+    input_matrix .+= T(weight)
+    f_sample = getfield(@__MODULE__, sampling_type)
+    f_sample(rng, input_matrix; kwargs...)
     return input_matrix
 end
 
-function irrational(rng::AbstractRNG, ::Type{T}, res_size::Int, in_size::Int;
-        irrational::Irrational=pi, start::Int=1,
-        weight::Number=T(0.1)) where {T <: Number}
-    setprecision(BigFloat, Int(ceil(log2(10) * (res_size * in_size + start + 1))))
-    ir_string = string(BigFloat(irrational)) |> collect
-    deleteat!(ir_string, findall(x -> x == '.', ir_string))
-    ir_array = DeviceAgnostic.zeros(rng, T, length(ir_string))
-    input_matrix = DeviceAgnostic.zeros(rng, T, res_size, in_size)
-
-    for idx in eachindex(ir_string)
-        ir_array[idx] = parse(Int, ir_string[idx])
-    end
-
-    for idx in 1:res_size
-        for jdx in 1:in_size
-            random_number = DeviceAgnostic.rand(rng, T)
-            input_matrix[idx, jdx] = random_number < 0.5 ? -T(weight) : T(weight)
-        end
-    end
-
-    return T.(input_matrix)
-end
-
 @doc raw"""
     chebyshev_mapping([rng], [T], dims...;
         amplitude=one(T), sine_divisor=one(T),
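
Under the new scheme, `minimal_init` first fills the matrix with `weight` and then applies an in-place sampler resolved by name through `getfield` (default `:bernoulli!`, now defined in esn/inits_components.jl), forwarding the remaining keywords to it. A usage sketch, assuming the new `bernoulli!` keeps the sign-probability keyword `p` of the removed `bernoulli`:

    using ReservoirComputing

    # higher p -> more positive signs, as in the docstring example above
    res_input = minimal_init(8, 3; weight=0.1, sampling_type=:bernoulli!, p=0.8)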
@@ -689,30 +622,18 @@ julia> res_matrix = delay_line(5, 5; weight=1)
     IEEE transactions on neural networks 22.1 (2010): 131-144.
 """
 function delay_line(rng::AbstractRNG, ::Type{T}, dims::Integer...;
-        weight=T(0.1), shift::Int=1, return_sparse::Bool=false) where {T <: Number}
+        weight=T(0.1), shift::Int=1, return_sparse::Bool=false,
+        kwargs...) where {T <: Number}
     throw_sparse_error(return_sparse)
     @assert length(dims) == 2&&dims[1] == dims[2] """\n
         The dimensions must define a square matrix
         (e.g., (100, 100))
     """
     reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
-    delay_line!(reservoir_matrix, weight, shift)
+    delay_line!(rng, reservoir_matrix, weight, shift; kwargs...)
     return return_init_as(Val(return_sparse), reservoir_matrix)
 end
 
-function delay_line!(reservoir_matrix::AbstractMatrix, weight::Number,
-        shift::Int)
-    weights = fill(weight, size(reservoir_matrix, 1) - shift)
-    delay_line!(reservoir_matrix, weights, shift)
-end
-
-function delay_line!(reservoir_matrix::AbstractMatrix, weight::AbstractVector,
-        shift::Int)
-    for idx in first(axes(reservoir_matrix, 1)):(last(axes(reservoir_matrix, 1)) - shift)
-        reservoir_matrix[idx + shift, idx] = weight[idx]
-    end
-end
-
 """
     delay_line_backward([rng], [T], dims...;
         weight=0.1, fb_weight=0.2, return_sparse=false)
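
The removed `delay_line!` methods are superseded by a generalized building block in esn/inits_components.jl (not part of this diff) that takes the rng first and accepts forwarded keywords. A minimal sketch of that generalization, reconstructed from the removed methods and the new call sites; the exact signature and keyword handling are assumptions:

    using Random: AbstractRNG

    # scalar weight: expand to a vector and delegate
    function delay_line!(rng::AbstractRNG, reservoir_matrix::AbstractMatrix,
            weight::Number, shift::Int; kwargs...)
        weights = fill(weight, size(reservoir_matrix, 1) - shift)
        return delay_line!(rng, reservoir_matrix, weights, shift; kwargs...)
    end

    # vector weights: fill the shift-th subdiagonal, as the removed method did;
    # kwargs would configure weight sampling in the real implementation
    function delay_line!(rng::AbstractRNG, reservoir_matrix::AbstractMatrix,
            weights::AbstractVector, shift::Int; kwargs...)
        for idx in first(axes(reservoir_matrix, 1)):(last(axes(reservoir_matrix, 1)) - shift)
            reservoir_matrix[idx + shift, idx] = weights[idx]
        end
        return reservoir_matrix
    end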
@@ -763,27 +684,15 @@ julia> res_matrix = delay_line_backward(Float16, 5, 5)
 """
 function delay_line_backward(rng::AbstractRNG, ::Type{T}, dims::Integer...;
         weight=T(0.1), fb_weight=T(0.2), shift::Int=1, fb_shift::Int=1,
-        return_sparse::Bool=false) where {T <: Number}
+        return_sparse::Bool=false, delay_kwargs::NamedTuple=NamedTuple(),
+        fb_kwargs::NamedTuple=NamedTuple()) where {T <: Number}
     throw_sparse_error(return_sparse)
     reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
-    delay_line!(reservoir_matrix, weight, shift)
-    backward_connection!(reservoir_matrix, fb_weight, fb_shift)
+    delay_line!(rng, reservoir_matrix, weight, shift; delay_kwargs...)
+    backward_connection!(rng, reservoir_matrix, fb_weight, fb_shift; fb_kwargs...)
     return return_init_as(Val(return_sparse), reservoir_matrix)
 end
 
-function backward_connection!(reservoir_matrix::AbstractMatrix, weight::Number,
-        shift::Int)
-    weights = fill(weight, size(reservoir_matrix, 1) - shift)
-    backward_connection!(reservoir_matrix, weights, shift)
-end
-
-function backward_connection!(reservoir_matrix::AbstractMatrix, weight::AbstractVector,
-        shift::Int)
-    for idx in first(axes(reservoir_matrix, 1)):(last(axes(reservoir_matrix, 1)) - shift)
-        reservoir_matrix[idx, idx + shift] = weight[idx]
-    end
-end
-
 """
     cycle_jumps([rng], [T], dims...;
         cycle_weight=0.1, jump_weight=0.1, jump_size=3, return_sparse=false)
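
The split into `delay_kwargs` and `fb_kwargs` exists because a single flat `kwargs...` could not be routed separately to the forward and feedback building blocks. A usage sketch; passing `sampling_type` inside the tuples is an assumption about what the generalized components accept:

    using ReservoirComputing

    res_matrix = delay_line_backward(100, 100; weight=0.1, fb_weight=0.2,
        delay_kwargs=(sampling_type=:bernoulli!,),
        fb_kwargs=(sampling_type=:bernoulli!,))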
@@ -835,26 +744,17 @@ julia> res_matrix = cycle_jumps(5, 5; jump_size=2)
 """
 function cycle_jumps(rng::AbstractRNG, ::Type{T}, dims::Integer...;
         cycle_weight::Number=T(0.1), jump_weight::Number=T(0.1),
-        jump_size::Int=3, return_sparse::Bool=false) where {T <: Number}
+        jump_size::Int=3, return_sparse::Bool=false,
+        cycle_kwargs::NamedTuple=NamedTuple(), jump_kwargs::NamedTuple=NamedTuple()) where {T <:
+        Number}
     throw_sparse_error(return_sparse)
     res_size = first(dims)
     reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
-    simple_cycle!(reservoir_matrix, cycle_weight)
-    add_jumps!(reservoir_matrix, cycle_weight, jump_size)
+    simple_cycle!(rng, reservoir_matrix, cycle_weight; cycle_kwargs...)
+    add_jumps!(rng, reservoir_matrix, cycle_weight, jump_size; jump_kwargs...)
     return return_init_as(Val(return_sparse), reservoir_matrix)
 end
 
-function add_jumps!(reservoir_matrix::AbstractMatrix, weight::Number, jump_size::Int)
-    for idx in 1:jump_size:(size(reservoir_matrix, 1) - jump_size)
-        tmp = (idx + jump_size) % size(reservoir_matrix, 1)
-        if tmp == 0
-            tmp = size(reservoir_matrix, 1)
-        end
-        reservoir_matrix[idx, tmp] = weight
-        reservoir_matrix[tmp, idx] = weight
-    end
-end
-
 """
     simple_cycle([rng], [T], dims...;
         weight=0.1, return_sparse=false)
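
The removed `add_jumps!` follows the same migration as `delay_line!`: the generalized version takes the rng and forwarded keywords. A sketch reconstructed from the removed body; signature and keyword handling are assumptions:

    using Random: AbstractRNG

    # symmetric jump connections every jump_size steps around the cycle;
    # kwargs are reserved for weight sampling in the real implementation
    function add_jumps!(rng::AbstractRNG, reservoir_matrix::AbstractMatrix,
            weight::Number, jump_size::Int; kwargs...)
        for idx in 1:jump_size:(size(reservoir_matrix, 1) - jump_size)
            tmp = (idx + jump_size) % size(reservoir_matrix, 1)
            if tmp == 0
                tmp = size(reservoir_matrix, 1)
            end
            reservoir_matrix[idx, tmp] = weight
            reservoir_matrix[tmp, idx] = weight
        end
        return reservoir_matrix
    end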
@@ -900,25 +800,13 @@ julia> res_matrix = simple_cycle(5, 5; weight=11)
     IEEE transactions on neural networks 22.1 (2010): 131-144.
 """
 function simple_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
-        weight=T(0.1), return_sparse::Bool=false) where {T <: Number}
+        weight=T(0.1), return_sparse::Bool=false, kwargs...) where {T <: Number}
     throw_sparse_error(return_sparse)
     reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
-    simple_cycle!(reservoir_matrix, weight)
+    simple_cycle!(rng, reservoir_matrix, weight; kwargs...)
     return return_init_as(Val(return_sparse), reservoir_matrix)
 end
 
-function simple_cycle!(reservoir_matrix::AbstractMatrix, weight::Number)
-    weights = fill(weight, size(reservoir_matrix, 1))
-    simple_cycle!(reservoir_matrix, weights)
-end
-
-function simple_cycle!(reservoir_matrix::AbstractMatrix, weight::AbstractVector)
-    for idx in first(axes(reservoir_matrix, 1)):(last(axes(reservoir_matrix, 1)) - 1)
-        reservoir_matrix[idx + 1, idx] = weight[idx]
-    end
-    reservoir_matrix[1, end] = weight[end]
-end
-
 """
     pseudo_svd([rng], [T], dims...;
         max_value=1.0, sparsity=0.1, sorted=true, reverse_sort=false,
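
Likewise for `simple_cycle!`; a sketch of the generalized form, mirroring the removed methods (hypothetical, as above):

    using Random: AbstractRNG

    function simple_cycle!(rng::AbstractRNG, reservoir_matrix::AbstractMatrix,
            weight::Number; kwargs...)
        weights = fill(weight, size(reservoir_matrix, 1))
        return simple_cycle!(rng, reservoir_matrix, weights; kwargs...)
    end

    # lower subdiagonal plus the wrap-around corner closes the cycle
    function simple_cycle!(rng::AbstractRNG, reservoir_matrix::AbstractMatrix,
            weights::AbstractVector; kwargs...)
        for idx in first(axes(reservoir_matrix, 1)):(last(axes(reservoir_matrix, 1)) - 1)
            reservoir_matrix[idx + 1, idx] = weights[idx]
        end
        reservoir_matrix[1, end] = weights[end]
        return reservoir_matrix
    end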
@@ -1524,12 +1412,14 @@ julia> reservoir_matrix = selfloop_delayline_backward(5, 5; weight=0.3)
 """
 function selfloop_delayline_backward(rng::AbstractRNG, ::Type{T}, dims::Integer...;
         shift::Int=1, fb_shift::Int=2, weight=T(0.1f0), fb_weight=weight,
-        selfloop_weight=T(0.1f0), return_sparse::Bool=false) where {T <: Number}
+        selfloop_weight=T(0.1f0), return_sparse::Bool=false,
+        delay_kwargs::NamedTuple=NamedTuple(), fb_kwargs::NamedTuple=NamedTuple()) where {T <:
+        Number}
     throw_sparse_error(return_sparse)
     reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
     reservoir_matrix += T(selfloop_weight) .* I(dims[1])
-    delay_line!(reservoir_matrix, weight, shift)
-    backward_connection!(reservoir_matrix, fb_weight, fb_shift)
+    delay_line!(rng, reservoir_matrix, weight, shift; delay_kwargs...)
+    backward_connection!(rng, reservoir_matrix, fb_weight, fb_shift; fb_kwargs...)
     return return_init_as(Val(return_sparse), reservoir_matrix)
 end
 
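All the new keywords default to empty NamedTuples, so existing calls are unaffected; the docstring example above still holds:

    using ReservoirComputing

    reservoir_matrix = selfloop_delayline_backward(5, 5; weight=0.3)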
@@ -1596,11 +1486,11 @@ julia> reservoir_matrix = selfloop_forward_connection(5, 5; weight=0.5)
 """
 function selfloop_forward_connection(rng::AbstractRNG, ::Type{T}, dims::Integer...;
         weight=T(0.1f0), selfloop_weight=T(0.1f0), shift::Int=2,
-        return_sparse::Bool=false) where {T <: Number}
+        return_sparse::Bool=false, kwargs...) where {T <: Number}
     throw_sparse_error(return_sparse)
     reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
     reservoir_matrix += T(selfloop_weight) .* I(dims[1])
-    delay_line!(reservoir_matrix, weight, shift)
+    delay_line!(rng, reservoir_matrix, weight, shift; kwargs...)
     return return_init_as(Val(return_sparse), reservoir_matrix)
 end
 
@@ -1662,21 +1552,20 @@ julia> reservoir_matrix = forward_connection(5, 5; weight=0.5)
     International Journal of Computational Science and Engineering 19.3 (2019): 407-417.
 """
 function forward_connection(rng::AbstractRNG, ::Type{T}, dims::Integer...;
-        weight=T(0.1f0), return_sparse::Bool=false) where {T <: Number}
+        weight=T(0.1f0), return_sparse::Bool=false, kwargs...) where {T <: Number}
     throw_sparse_error(return_sparse)
     reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
-    delay_line!(reservoir_matrix, weight, 2)
+    delay_line!(rng, reservoir_matrix, weight, 2; kwargs...)
     return return_init_as(Val(return_sparse), reservoir_matrix)
 end
 
 ### fallbacks
 #fallbacks for initializers #eventually to remove once migrated to WeightInitializers.jl
 for initializer in (:rand_sparse, :delay_line, :delay_line_backward, :cycle_jumps,
-        :simple_cycle, :pseudo_svd, :chaotic_init,
-        :scaled_rand, :weighted_init, :informed_init, :minimal_init, :chebyshev_mapping,
-        :logistic_mapping, :modified_lm, :low_connectivity, :double_cycle, :selfloop_cycle,
-        :selfloop_feedback_cycle, :selfloop_delayline_backward, :selfloop_forward_connection,
-        :forward_connection)
+        :simple_cycle, :pseudo_svd, :chaotic_init, :scaled_rand, :weighted_init,
+        :informed_init, :minimal_init, :chebyshev_mapping, :logistic_mapping, :modified_lm,
+        :low_connectivity, :double_cycle, :selfloop_cycle, :selfloop_feedback_cycle,
+        :selfloop_delayline_backward, :selfloop_forward_connection, :forward_connection)
     @eval begin
        function ($initializer)(dims::Integer...; kwargs...)
            return $initializer(Utils.default_rng(), Float32, dims...; kwargs...)
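
The `@eval` loop generates, for each listed initializer, the positional fallback shown, defaulting the rng and the element type to Float32. The two calls below should therefore build the same matrix; this assumes `Utils.default_rng()` agrees with `Random.default_rng()` (for `delay_line` with a scalar weight the result is deterministic either way):

    using Random
    using ReservoirComputing

    W1 = delay_line(100, 100; weight=0.1)  # generated fallback
    W2 = delay_line(Random.default_rng(), Float32, 100, 100; weight=0.1)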
