Skip to content

Commit 837ac47

Browse files
feat: increase specificity for scaled_rand init
1 parent 39ca9bf commit 837ac47

File tree

1 file changed

+57
-43
lines changed

1 file changed

+57
-43
lines changed

src/esn/esn_inits.jl

Lines changed: 57 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,27 @@ julia> res_input = scaled_rand(8, 3)
3636
```
3737
"""
3838
function scaled_rand(rng::AbstractRNG, ::Type{T}, dims::Integer...;
39-
scaling::Number = T(0.1)) where {T <: Number}
39+
scaling::Union{Number, Tuple} = T(0.1)) where {T <: Number}
4040
res_size, in_size = dims
41-
layer_matrix = (DeviceAgnostic.rand(rng, T, res_size, in_size) .- T(0.5)) .*
42-
(T(2) * T(scaling))
41+
layer_matrix = DeviceAgnostic.rand(rng, T, res_size, in_size)
42+
apply_scale!(layer_matrix, scaling, T)
4343
return layer_matrix
4444
end
4545

46+
function apply_scale!(input_matrix, scaling::Number, ::Type{T}) where {T}
47+
@. input_matrix = (input_matrix - T(0.5)) * (T(2) * T(scaling))
48+
return input_matrix
49+
end
50+
51+
function apply_scale!(input_matrix,
52+
scaling::Tuple{<:Number, <:Number}, ::Type{T}) where {T}
53+
lower, upper = T(scaling[1]), T(scaling[2])
54+
@assert lower<upper "lower < upper required"
55+
scale = upper - lower
56+
@. input_matrix = input_matrix * scale + lower
57+
return input_matrix
58+
end
59+
4660
"""
4761
weighted_init([rng], [T], dims...;
4862
scaling=0.1, return_sparse=false)
@@ -146,11 +160,11 @@ warning.
146160
```jldoctest
147161
julia> res_input = weighted_minimal(8, 3)
148162
┌ Warning: Reservoir size has changed!
149-
150-
│ Computed reservoir size (6) does not equal the provided reservoir size (8).
151-
152-
│ Using computed value (6). Make sure to modify the reservoir initializer accordingly.
153-
163+
164+
│ Computed reservoir size (6) does not equal the provided reservoir size (8).
165+
166+
│ Using computed value (6). Make sure to modify the reservoir initializer accordingly.
167+
154168
└ @ ReservoirComputing ~/.julia/dev/ReservoirComputing/src/esn/esn_inits.jl:159
155169
6×3 Matrix{Float32}:
156170
0.1 0.0 0.0
@@ -370,7 +384,7 @@ using a sine function and subsequent rows are iteratively generated
370384
via the Chebyshev mapping. The first row is defined as:
371385
372386
```math
373-
W[1, j] = \text{amplitude} \cdot \sin(j \cdot \pi / (\text{sine_divisor}
387+
W[1, j] = \text{amplitude} \cdot \sin(j \cdot \pi / (\text{sine_divisor}
374388
\cdot \text{n_cols}))
375389
```
376390
@@ -448,7 +462,7 @@ Generate an input weight matrix using a logistic mapping [Wang2022](@cite)
448462
The first row is initialized using a sine function:
449463
450464
```math
451-
W[1, j] = \text{amplitude} \cdot \sin(j \cdot \pi /
465+
W[1, j] = \text{amplitude} \cdot \sin(j \cdot \pi /
452466
(\text{sine_divisor} \cdot in_size))
453467
```
454468
@@ -527,7 +541,7 @@ as follows:
527541
- The first element of the chain is initialized using a sine function:
528542
529543
```math
530-
W[1,j] = \text{amplitude} \cdot \sin( (j \cdot \pi) /
544+
W[1,j] = \text{amplitude} \cdot \sin( (j \cdot \pi) /
531545
(\text{factor} \cdot \text{n} \cdot \text{sine_divisor}) )
532546
```
533547
where `j` is the index corresponding to the input and `n` is the number of inputs.
@@ -540,7 +554,7 @@ as follows:
540554
541555
The resulting matrix has dimensions `(factor * in_size) x in_size`, where
542556
`in_size` corresponds to the number of columns provided in `dims`.
543-
If the provided number of rows does not match `factor * in_size`
557+
If the provided number of rows does not match `factor * in_size`
544558
the number of rows is overridden.
545559
546560
# Arguments
@@ -576,15 +590,15 @@ julia> modified_lm(20, 10; factor=2)
576590
577591
julia> modified_lm(12, 4; factor=3)
578592
12×4 SparseArrays.SparseMatrixCSC{Float32, Int64} with 9 stored entries:
579-
⋅ ⋅ ⋅ ⋅
580-
⋅ ⋅ ⋅ ⋅
581-
⋅ ⋅ ⋅ ⋅
582-
⋅ 0.0133075 ⋅ ⋅
583-
⋅ 0.0308564 ⋅ ⋅
584-
⋅ 0.070275 ⋅ ⋅
585-
⋅ ⋅ 0.0265887 ⋅
586-
⋅ ⋅ 0.0608222 ⋅
587-
⋅ ⋅ 0.134239 ⋅
593+
⋅ ⋅ ⋅ ⋅
594+
⋅ ⋅ ⋅ ⋅
595+
⋅ ⋅ ⋅ ⋅
596+
⋅ 0.0133075 ⋅ ⋅
597+
⋅ 0.0308564 ⋅ ⋅
598+
⋅ 0.070275 ⋅ ⋅
599+
⋅ ⋅ 0.0265887 ⋅
600+
⋅ ⋅ 0.0608222 ⋅
601+
⋅ ⋅ 0.134239 ⋅
588602
⋅ ⋅ ⋅ 0.0398177
589603
⋅ ⋅ ⋅ 0.0898457
590604
⋅ ⋅ ⋅ 0.192168
@@ -671,7 +685,7 @@ function rand_sparse(rng::AbstractRNG, ::Type{T}, dims::Integer...;
671685
end
672686

673687
"""
674-
pseudo_svd([rng], [T], dims...;
688+
pseudo_svd([rng], [T], dims...;
675689
max_value=1.0, sparsity=0.1, sorted=true, reverse_sort=false,
676690
return_sparse=false)
677691
@@ -821,15 +835,15 @@ closest valid order is used.
821835
822836
```jldoctest
823837
julia> res_matrix = chaotic_init(8, 8)
824-
┌ Warning:
825-
838+
┌ Warning:
839+
826840
│ Adjusting reservoir matrix order:
827841
│ from 8 (requested) to 4
828-
│ based on computed bit precision = 1.
829-
842+
│ based on computed bit precision = 1.
843+
830844
└ @ ReservoirComputing ~/.julia/dev/ReservoirComputing/src/esn/esn_inits.jl:805
831845
4×4 SparseArrays.SparseMatrixCSC{Float32, Int64} with 6 stored entries:
832-
⋅ -0.600945 ⋅ ⋅
846+
⋅ -0.600945 ⋅ ⋅
833847
⋅ ⋅ 0.132667 2.21354
834848
⋅ -2.60383 ⋅ -2.90391
835849
-0.578156 ⋅ ⋅ ⋅
@@ -1148,7 +1162,7 @@ function delay_line_backward(rng::AbstractRNG, ::Type{T}, dims::Integer...;
11481162
end
11491163

11501164
"""
1151-
cycle_jumps([rng], [T], dims...;
1165+
cycle_jumps([rng], [T], dims...;
11521166
cycle_weight=0.1, jump_weight=0.1, jump_size=3, return_sparse=false,
11531167
cycle_kwargs=(), jump_kwargs=())
11541168
@@ -1234,7 +1248,7 @@ function cycle_jumps(rng::AbstractRNG, ::Type{T}, dims::Integer...;
12341248
end
12351249

12361250
"""
1237-
simple_cycle([rng], [T], dims...;
1251+
simple_cycle([rng], [T], dims...;
12381252
weight=0.1, return_sparse=false,
12391253
kwargs...)
12401254
@@ -1303,7 +1317,7 @@ function simple_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
13031317
end
13041318

13051319
"""
1306-
double_cycle([rng], [T], dims...;
1320+
double_cycle([rng], [T], dims...;
13071321
cycle_weight=0.1, second_cycle_weight=0.1,
13081322
return_sparse=false)
13091323
@@ -1358,7 +1372,7 @@ function double_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
13581372
end
13591373

13601374
"""
1361-
true_double_cycle([rng], [T], dims...;
1375+
true_double_cycle([rng], [T], dims...;
13621376
cycle_weight=0.1, second_cycle_weight=0.1,
13631377
return_sparse=false)
13641378
@@ -1427,7 +1441,7 @@ function true_double_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
14271441
end
14281442

14291443
@doc raw"""
1430-
selfloop_cycle([rng], [T], dims...;
1444+
selfloop_cycle([rng], [T], dims...;
14311445
cycle_weight=0.1, selfloop_weight=0.1,
14321446
return_sparse=false, kwargs...)
14331447
@@ -1518,7 +1532,7 @@ function selfloop_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
15181532
end
15191533

15201534
@doc raw"""
1521-
selfloop_feedback_cycle([rng], [T], dims...;
1535+
selfloop_feedback_cycle([rng], [T], dims...;
15221536
cycle_weight=0.1, selfloop_weight=0.1,
15231537
return_sparse=false)
15241538
@@ -1601,7 +1615,7 @@ function selfloop_feedback_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
16011615
end
16021616

16031617
@doc raw"""
1604-
selfloop_delayline_backward([rng], [T], dims...;
1618+
selfloop_delayline_backward([rng], [T], dims...;
16051619
weight=0.1, selfloop_weight=0.1, fb_weight=0.1,
16061620
fb_shift=2, return_sparse=false, fb_kwargs=(),
16071621
selfloop_kwargs=(), delay_kwargs=())
@@ -1707,7 +1721,7 @@ function selfloop_delayline_backward(rng::AbstractRNG, ::Type{T}, dims::Integer.
17071721
end
17081722

17091723
@doc raw"""
1710-
selfloop_forward_connection([rng], [T], dims...;
1724+
selfloop_forward_connection([rng], [T], dims...;
17111725
weight=0.1, selfloop_weight=0.1,
17121726
return_sparse=false, selfloop_kwargs=(),
17131727
delay_kwargs=())
@@ -1749,7 +1763,7 @@ W_{i,j} =
17491763
Default is 0.1.
17501764
- `return_sparse`: flag for returning a `sparse` matrix.
17511765
Default is `false`.
1752-
- `delay_kwargs` and `selfloop_kwargs`: named tuples that control the kwargs for the
1766+
- `delay_kwargs` and `selfloop_kwargs`: named tuples that control the kwargs for the
17531767
delay line weight and self loop weights respectively. The kwargs are as follows:
17541768
+ `sampling_type`: Sampling that decides the distribution of `weight` negative numbers.
17551769
If set to `:no_sample` the sign is unchanged. If set to `:bernoulli_sample!` then each
@@ -1801,7 +1815,7 @@ function selfloop_forward_connection(rng::AbstractRNG, ::Type{T}, dims::Integer.
18011815
end
18021816

18031817
@doc raw"""
1804-
forward_connection([rng], [T], dims...;
1818+
forward_connection([rng], [T], dims...;
18051819
weight=0.1, selfloop_weight=0.1,
18061820
return_sparse=false)
18071821
@@ -1887,8 +1901,8 @@ end
18871901
return_sparse=false)
18881902
18891903
Creates a block‐diagonal matrix consisting of square blocks of size
1890-
`block_size` along the main diagonal [Ma2023](@cite).
1891-
Each block may be filled with
1904+
`block_size` along the main diagonal [Ma2023](@cite).
1905+
Each block may be filled with
18921906
- a single scalar
18931907
- a vector of per‐block weights (length = number of blocks)
18941908
@@ -1897,21 +1911,21 @@ Each block may be filled with
18971911
```math
18981912
W_{i,j} =
18991913
\begin{cases}
1900-
w_b, & \text{if }\left\lfloor\frac{i-1}{s}\right\rfloor = \left\lfloor\frac{j-1}{s}\right\rfloor = b,\;
1914+
w_b, & \text{if }\left\lfloor\frac{i-1}{s}\right\rfloor = \left\lfloor\frac{j-1}{s}\right\rfloor = b,\;
19011915
s = \text{block\_size},\; b=0,\dots,nb-1, \\
19021916
0, & \text{otherwise,}
19031917
\end{cases}
19041918
```
19051919
19061920
# Arguments
19071921
1908-
- `rng`: Random number generator. Default is `Utils.default_rng()`.
1909-
- `T`: Element type of the matrix. Default is `Float32`.
1922+
- `rng`: Random number generator. Default is `Utils.default_rng()`.
1923+
- `T`: Element type of the matrix. Default is `Float32`.
19101924
- `dims`: Dimensions of the output matrix (must be two-dimensional).
19111925
19121926
# Keyword arguments
19131927
1914-
- `weight`:
1928+
- `weight`:
19151929
- scalar: every block is filled with that value
19161930
- vector: length = number of blocks, one constant per block
19171931
Default is `1.0`.

0 commit comments

Comments
 (0)