Skip to content

Commit 10befc8

Browse files
feat: add uniform dimension check to reservoir inits
feat: add uniform dimension check to reservoir inits
2 parents 934acff + ce480d5 commit 10befc8

File tree

2 files changed

+37
-25
lines changed

2 files changed

+37
-25
lines changed

src/esn/esn_inits.jl

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,7 @@ function weighted_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
8989
throw_sparse_error(return_sparse)
9090
approx_res_size, in_size = dims
9191
res_size = Int(floor(approx_res_size / in_size) * in_size)
92-
if res_size != approx_res_size
93-
@warn """Reservoir size has changed!\n
94-
Computed reservoir size ($res_size) does not equal the \
95-
provided reservoir size ($approx_res_size). \n
96-
Using computed value ($res_size). Make sure to modify the \
97-
reservoir initializer accordingly. \n
98-
"""
99-
end
92+
check_modified_ressize(res_size, approx_res_size)
10093
layer_matrix = DeviceAgnostic.zeros(rng, T, res_size, in_size)
10194
q = floor(Int, res_size / in_size)
10295

@@ -207,14 +200,7 @@ function weighted_minimal(rng::AbstractRNG, ::Type{T}, dims::Integer...;
207200
throw_sparse_error(return_sparse)
208201
approx_res_size, in_size = dims
209202
res_size = Int(floor(approx_res_size / in_size) * in_size)
210-
if res_size != approx_res_size
211-
@warn """Reservoir size has changed!\n
212-
Computed reservoir size ($res_size) does not equal the \
213-
provided reservoir size ($approx_res_size). \n
214-
Using computed value ($res_size). Make sure to modify the \
215-
reservoir initializer accordingly. \n
216-
"""
217-
end
203+
check_modified_ressize(res_size, approx_res_size)
218204
layer_matrix = DeviceAgnostic.zeros(rng, T, res_size, in_size)
219205
q = floor(Int, res_size / in_size)
220206

@@ -705,6 +691,7 @@ function rand_sparse(rng::AbstractRNG, ::Type{T}, dims::Integer...;
705691
radius::Number=T(1.0), sparsity::Number=T(0.1), std::Number=T(1.0),
706692
return_sparse::Bool=false) where {T <: Number}
707693
throw_sparse_error(return_sparse)
694+
check_res_size(dims...)
708695
lcl_sparsity = T(1) - sparsity #consistency with current implementations
709696
reservoir_matrix = sparse_init(rng, T, dims...; sparsity=lcl_sparsity, std=std)
710697
reservoir_matrix = scale_radius!(reservoir_matrix, T(radius))
@@ -764,6 +751,7 @@ function pseudo_svd(rng::AbstractRNG, ::Type{T}, dims::Integer...;
764751
reverse_sort::Bool=false, return_sparse::Bool=false,
765752
return_diag::Bool=false) where {T <: Number}
766753
throw_sparse_error(return_sparse)
754+
check_res_size(dims...)
767755
reservoir_matrix = create_diag(rng, T, dims[1],
768756
T(max_value);
769757
sorted=sorted,
@@ -888,6 +876,7 @@ function chaotic_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
888876
extra_edge_probability::AbstractFloat=T(0.1), spectral_radius::AbstractFloat=one(T),
889877
return_sparse::Bool=false) where {T <: Number}
890878
throw_sparse_error(return_sparse)
879+
check_res_size(dims...)
891880
requested_order = first(dims)
892881
if length(dims) > 1 && dims[2] != requested_order
893882
@warn """\n
@@ -980,12 +969,8 @@ otherwise, it generates a random connectivity pattern.
980969
function low_connectivity(rng::AbstractRNG, ::Type{T}, dims::Integer...;
981970
return_sparse::Bool=false, connected::Bool=false,
982971
in_degree::Integer=1, kwargs...) where {T <: Number}
972+
check_res_size(dims...)
983973
res_size = dims[1]
984-
if length(dims) != 2 || dims[1] != dims[2]
985-
error("""
986-
Internal reservoir matrix must be square. Got dims = $(dims)
987-
""")
988-
end
989974
if in_degree > res_size
990975
error("""
991976
In-degree k (got k=$(in_degree)) cannot exceed number of nodes N=$(res_size)
@@ -1113,10 +1098,7 @@ function delay_line(rng::AbstractRNG, ::Type{T}, dims::Integer...;
11131098
weight::Union{Number, AbstractVector}=T(0.1), shift::Integer=1,
11141099
return_sparse::Bool=false, kwargs...) where {T <: Number}
11151100
throw_sparse_error(return_sparse)
1116-
@assert length(dims) == 2&&dims[1] == dims[2] """\n
1117-
The dimensions must define a square matrix
1118-
(e.g., (100, 100))
1119-
"""
1101+
check_res_size(dims...)
11201102
reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
11211103
delay_line!(rng, reservoir_matrix, T.(weight), shift; kwargs...)
11221104
return return_init_as(Val(return_sparse), reservoir_matrix)
@@ -1207,6 +1189,7 @@ function delay_line_backward(rng::AbstractRNG, ::Type{T}, dims::Integer...;
12071189
delay_kwargs::NamedTuple=NamedTuple(),
12081190
fb_kwargs::NamedTuple=NamedTuple()) where {T <: Number}
12091191
throw_sparse_error(return_sparse)
1192+
check_res_size(dims...)
12101193
reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
12111194
delay_line!(rng, reservoir_matrix, T.(weight), shift; delay_kwargs...)
12121195
backward_connection!(rng, reservoir_matrix, T.(fb_weight), fb_shift; fb_kwargs...)
@@ -1295,6 +1278,7 @@ function cycle_jumps(rng::AbstractRNG, ::Type{T}, dims::Integer...;
12951278
cycle_kwargs::NamedTuple=NamedTuple(),
12961279
jump_kwargs::NamedTuple=NamedTuple()) where {T <: Number}
12971280
throw_sparse_error(return_sparse)
1281+
check_res_size(dims...)
12981282
res_size = first(dims)
12991283
reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
13001284
simple_cycle!(rng, reservoir_matrix, T.(cycle_weight); cycle_kwargs...)
@@ -1369,6 +1353,7 @@ function simple_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
13691353
weight::Union{Number, AbstractVector}=T(0.1),
13701354
return_sparse::Bool=false, kwargs...) where {T <: Number}
13711355
throw_sparse_error(return_sparse)
1356+
check_res_size(dims...)
13721357
reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
13731358
simple_cycle!(rng, reservoir_matrix, T.(weight); kwargs...)
13741359
return return_init_as(Val(return_sparse), reservoir_matrix)
@@ -1418,6 +1403,7 @@ function double_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
14181403
second_cycle_weight::Union{Number, AbstractVector}=T(0.1),
14191404
return_sparse::Bool=false) where {T <: Number}
14201405
throw_sparse_error(return_sparse)
1406+
check_res_size(dims...)
14211407
reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
14221408

14231409
for uidx in first(axes(reservoir_matrix, 1)):(last(axes(reservoir_matrix, 1)) - 1)
@@ -1500,6 +1486,7 @@ function true_double_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
15001486
return_sparse::Bool=false, cycle_kwargs::NamedTuple=NamedTuple(),
15011487
second_cycle_kwargs::NamedTuple=NamedTuple()) where {T <: Number}
15021488
throw_sparse_error(return_sparse)
1489+
check_res_size(dims...)
15031490
reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
15041491
simple_cycle!(rng, reservoir_matrix, cycle_weight; cycle_kwargs...)
15051492
reverse_simple_cycle!(
@@ -1594,6 +1581,7 @@ function selfloop_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
15941581
selfloop_weight::Union{Number, AbstractVector}=T(0.1f0),
15951582
return_sparse::Bool=false, kwargs...) where {T <: Number}
15961583
throw_sparse_error(return_sparse)
1584+
check_res_size(dims...)
15971585
reservoir_matrix = simple_cycle(rng, T, dims...;
15981586
weight=T.(cycle_weight), return_sparse=false)
15991587
self_loop!(rng, reservoir_matrix, T.(selfloop_weight); kwargs...)
@@ -1671,6 +1659,7 @@ function selfloop_feedback_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
16711659
selfloop_weight::Union{Number, AbstractVector}=T(0.1f0),
16721660
return_sparse::Bool=false) where {T <: Number}
16731661
throw_sparse_error(return_sparse)
1662+
check_res_size(dims...)
16741663
reservoir_matrix = simple_cycle(rng, T, dims...;
16751664
weight=T.(cycle_weight), return_sparse=false)
16761665
for idx in axes(reservoir_matrix, 1)
@@ -1788,6 +1777,7 @@ function selfloop_delayline_backward(rng::AbstractRNG, ::Type{T}, dims::Integer.
17881777
fb_kwargs::NamedTuple=NamedTuple(),
17891778
selfloop_kwargs::NamedTuple=NamedTuple()) where {T <: Number}
17901779
throw_sparse_error(return_sparse)
1780+
check_res_size(dims...)
17911781
reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
17921782
self_loop!(rng, reservoir_matrix, T.(selfloop_weight); selfloop_kwargs...)
17931783
delay_line!(rng, reservoir_matrix, T.(weight), shift; delay_kwargs...)
@@ -1886,6 +1876,7 @@ function selfloop_forward_connection(rng::AbstractRNG, ::Type{T}, dims::Integer.
18861876
return_sparse::Bool=false, delay_kwargs::NamedTuple=NamedTuple(),
18871877
selfloop_kwargs::NamedTuple=NamedTuple()) where {T <: Number}
18881878
throw_sparse_error(return_sparse)
1879+
check_res_size(dims...)
18891880
reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
18901881
self_loop!(rng, reservoir_matrix, T.(selfloop_weight); selfloop_kwargs...)
18911882
delay_line!(rng, reservoir_matrix, T.(weight), shift; delay_kwargs...)
@@ -1971,6 +1962,7 @@ function forward_connection(rng::AbstractRNG, ::Type{T}, dims::Integer...;
19711962
weight::Union{Number, AbstractVector}=T(0.1f0), return_sparse::Bool=false,
19721963
kwargs...) where {T <: Number}
19731964
throw_sparse_error(return_sparse)
1965+
check_res_size(dims...)
19741966
reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
19751967
delay_line!(rng, reservoir_matrix, T.(weight), 2; kwargs...)
19761968
return return_init_as(Val(return_sparse), reservoir_matrix)

src/esn/inits_components.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,26 @@ function throw_sparse_error(return_sparse::Bool)
1515
end
1616
end
1717

18+
function check_modified_ressize(res_size::Integer, approx_res_size::Integer)
19+
if res_size != approx_res_size
20+
@warn """Reservoir size has changed!\n
21+
Computed reservoir size ($res_size) does not equal the \
22+
provided reservoir size ($approx_res_size). \n
23+
Using computed value ($res_size). Make sure to modify the \
24+
reservoir initializer accordingly. \n
25+
"""
26+
end
27+
end
28+
29+
function check_res_size(dims::Integer...)
30+
if length(dims) != 2 || dims[1] != dims[2]
31+
error("""\n
32+
Internal reservoir matrix must be square (e.g., (100, 100)).
33+
Got dims = $(dims)\n
34+
""")
35+
end
36+
end
37+
1838
## scale spectral radius
1939
"""
2040
scale_radius!(matrix, radius)

0 commit comments

Comments
 (0)