Commit 95b5771

addressing type conversion in inits
1 parent f9cde1e commit 95b5771
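
The diff below applies one pattern throughout src/esn/esn_inits.jl: scalar keyword arguments (weights, scaling factors, radii, angles) are wrapped in T(...) before they are combined with, or stored into, arrays of element type T, and scale_radius! now returns the matrix it scales. A minimal sketch of why the conversion matters; the scale_columns helper and the literal values are illustrative only, not part of the package:

# Hypothetical helper illustrating the T(...) conversion pattern used in this commit.
function scale_columns(mat::AbstractMatrix{T}, scaling) where {T <: Number}
    # Without the conversion, a Float64 `scaling` (e.g. the literal 0.1) would
    # promote the broadcast result to Float64 even when T == Float32.
    return mat .* (T(2) * T(scaling))
end

mat32 = rand(Float32, 4, 3)
out = scale_columns(mat32, 0.1)   # 0.1 is a Float64 literal
@assert eltype(out) == Float32    # conversion keeps the requested element type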

File tree

1 file changed: +31 -31 lines

src/esn/esn_inits.jl

Lines changed: 31 additions & 31 deletions
@@ -5,7 +5,7 @@ end

 # error for sparse inits with no SparseArrays.jl call

-function throw_sparse_error(return_sparse)
+function throw_sparse_error(return_sparse::Bool)
     if return_sparse && !haskey(Base.loaded_modules, :SparseArrays)
         error("""\n
             Sparse output requested but SparseArrays.jl is not loaded.
@@ -18,7 +18,7 @@ end

 ## scale spectral radius

-function scale_radius!(reservoir_matrix, radius)
+function scale_radius!(reservoir_matrix::AbstractMatrix, radius::AbstractFloat)
     rho_w = maximum(abs.(eigvals(reservoir_matrix)))
     reservoir_matrix .*= radius / rho_w
     if Inf in unique(reservoir_matrix) || -Inf in unique(reservoir_matrix)
@@ -27,6 +27,7 @@ function scale_radius!(reservoir_matrix, radius)
             Increase res_size or increase sparsity.\n
         """)
     end
+    return reservoir_matrix
 end

 ### input layers
@@ -70,7 +71,7 @@ function scaled_rand(rng::AbstractRNG, ::Type{T}, dims::Integer...;
         scaling=T(0.1)) where {T <: Number}
     res_size, in_size = dims
     layer_matrix = (DeviceAgnostic.rand(rng, T, res_size, in_size) .- T(0.5)) .*
-                   (T(2) * scaling)
+                   (T(2) * T(scaling))
     return layer_matrix
 end

@@ -124,9 +125,8 @@ function weighted_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
     q = floor(Int, res_size / in_size)

     for idx in 1:in_size
-        layer_matrix[((idx - 1) * q + 1):((idx) * q), idx] = (DeviceAgnostic.rand(
-            rng, T, q) .-
-            T(0.5)) .* (T(2) * scaling)
+        sc_rand = (DeviceAgnostic.rand(rng, T, q) .- T(0.5)) .* (T(2) * T(scaling))
+        layer_matrix[((idx - 1) * q + 1):((idx) * q), idx] = sc_rand
     end

     return return_init_as(Val(return_sparse), layer_matrix)
@@ -179,7 +179,7 @@ function informed_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
         random_row_idx = idxs[DeviceAgnostic.rand(rng, T, 1:end)]
         random_clm_idx = range(1, state_size; step=1)[DeviceAgnostic.rand(rng, T, 1:end)]
         input_matrix[random_row_idx, random_clm_idx] = (DeviceAgnostic.rand(rng, T) -
-                                                        T(0.5)) .* (T(2) * scaling)
+                                                        T(0.5)) .* (T(2) * T(scaling))
     end

     for idx in 1:num_for_model
@@ -189,7 +189,7 @@ function informed_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
         random_clm_idx = range(state_size + 1, in_size; step=1)[DeviceAgnostic.rand(
             rng, T, 1:end)]
         input_matrix[random_row_idx, random_clm_idx] = (DeviceAgnostic.rand(rng, T) -
-                                                        T(0.5)) .* (T(2) * scaling)
+                                                        T(0.5)) .* (T(2) * T(scaling))
     end

     return input_matrix
@@ -287,9 +287,9 @@ function bernoulli(rng::AbstractRNG, ::Type{T}, res_size::Int, in_size::Int;
     for idx in 1:res_size
         for jdx in 1:in_size
             if DeviceAgnostic.rand(rng, T) < p
-                input_matrix[idx, jdx] = weight
+                input_matrix[idx, jdx] = T(weight)
             else
-                input_matrix[idx, jdx] = -weight
+                input_matrix[idx, jdx] = -T(weight)
             end
         end
     end
@@ -312,7 +312,7 @@ function irrational(rng::AbstractRNG, ::Type{T}, res_size::Int, in_size::Int;
     for idx in 1:res_size
         for jdx in 1:in_size
             random_number = DeviceAgnostic.rand(rng, T)
-            input_matrix[idx, jdx] = random_number < 0.5 ? -weight : weight
+            input_matrix[idx, jdx] = random_number < 0.5 ? -T(weight) : T(weight)
         end
     end

@@ -638,7 +638,7 @@ function rand_sparse(rng::AbstractRNG, ::Type{T}, dims::Integer...;
     throw_sparse_error(return_sparse)
     lcl_sparsity = T(1) - sparsity #consistency with current implementations
     reservoir_matrix = sparse_init(rng, T, dims...; sparsity=lcl_sparsity, std=std)
-    scale_radius!(reservoir_matrix, radius)
+    reservoir_matrix = scale_radius!(reservoir_matrix, T(radius))
     return return_init_as(Val(return_sparse), reservoir_matrix)
 end

@@ -697,7 +697,7 @@ function delay_line(rng::AbstractRNG, ::Type{T}, dims::Integer...;
     """

     for idx in first(axes(reservoir_matrix, 1)):(last(axes(reservoir_matrix, 1)) - 1)
-        reservoir_matrix[idx + 1, idx] = weight
+        reservoir_matrix[idx + 1, idx] = T(weight)
     end

     return return_init_as(Val(return_sparse), reservoir_matrix)
@@ -756,8 +756,8 @@ function delay_line_backward(rng::AbstractRNG, ::Type{T}, dims::Integer...;
     throw_sparse_error(return_sparse)
     reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
     for idx in first(axes(reservoir_matrix, 1)):(last(axes(reservoir_matrix, 1)) - 1)
-        reservoir_matrix[idx + 1, idx] = weight
-        reservoir_matrix[idx, idx + 1] = fb_weight
+        reservoir_matrix[idx + 1, idx] = T(weight)
+        reservoir_matrix[idx, idx + 1] = T(fb_weight)
     end
     return return_init_as(Val(return_sparse), reservoir_matrix)
 end
@@ -819,18 +819,18 @@ function cycle_jumps(rng::AbstractRNG, ::Type{T}, dims::Integer...;
     reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)

     for idx in first(axes(reservoir_matrix, 1)):(last(axes(reservoir_matrix, 1)) - 1)
-        reservoir_matrix[idx + 1, idx] = cycle_weight
+        reservoir_matrix[idx + 1, idx] = T(cycle_weight)
     end

-    reservoir_matrix[1, res_size] = cycle_weight
+    reservoir_matrix[1, res_size] = T(cycle_weight)

     for idx in 1:jump_size:(res_size - jump_size)
         tmp = (idx + jump_size) % res_size
         if tmp == 0
             tmp = res_size
         end
-        reservoir_matrix[idx, tmp] = jump_weight
-        reservoir_matrix[tmp, idx] = jump_weight
+        reservoir_matrix[idx, tmp] = T(jump_weight)
+        reservoir_matrix[tmp, idx] = T(jump_weight)
     end

     return return_init_as(Val(return_sparse), reservoir_matrix)
@@ -947,7 +947,7 @@ function pseudo_svd(rng::AbstractRNG, ::Type{T}, dims::Integer...;
         return_diag::Bool=false) where {T <: Number}
     throw_sparse_error(return_sparse)
     reservoir_matrix = create_diag(rng, T, dims[1],
-        max_value;
+        T(max_value);
         sorted=sorted,
         reverse_sort=reverse_sort)
     tmp_sparsity = get_sparsity(reservoir_matrix, dims[1])
@@ -1003,10 +1003,10 @@ function create_qmatrix(rng::AbstractRNG, ::Type{T}, dim::Number,
         qmatrix[idx, idx] = 1.0
     end

-    qmatrix[coord_i, coord_i] = cos(theta)
-    qmatrix[coord_j, coord_j] = cos(theta)
-    qmatrix[coord_i, coord_j] = -sin(theta)
-    qmatrix[coord_j, coord_i] = sin(theta)
+    qmatrix[coord_i, coord_i] = cos(T(theta))
+    qmatrix[coord_j, coord_j] = cos(T(theta))
+    qmatrix[coord_i, coord_j] = -sin(T(theta))
+    qmatrix[coord_j, coord_i] = sin(T(theta))
     return qmatrix
 end

@@ -1192,7 +1192,7 @@ function build_cycle(::Val{false}, rng::AbstractRNG, ::Type{T}, res_size::Int;
             reservoir_matrix[idx, jdx] = T(randn(rng))
         end
     end
-    scale_radius!(reservoir_matrix, radius)
+    reservoir_matrix = scale_radius!(reservoir_matrix, T(radius))
     return reservoir_matrix
 end

@@ -1204,7 +1204,7 @@ function build_cycle(::Val{true}, rng::AbstractRNG, ::Type{T}, res_size::Int;
         reservoir_matrix[perm[idx], perm[idx + 1]] = T(randn(rng))
     end
     reservoir_matrix[perm[res_size], perm[1]] = T(randn(rng))
-    scale_radius!(reservoir_matrix, radius)
+    reservoir_matrix = scale_radius!(reservoir_matrix, T(radius))
     if cut_cycle
         cut_cycle_edge!(reservoir_matrix, rng)
     end
@@ -1270,14 +1270,14 @@ function double_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
     reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)

     for uidx in first(axes(reservoir_matrix, 1)):(last(axes(reservoir_matrix, 1)) - 1)
-        reservoir_matrix[uidx + 1, uidx] = cycle_weight
+        reservoir_matrix[uidx + 1, uidx] = T(cycle_weight)
     end
     for lidx in (first(axes(reservoir_matrix, 1)) + 1):last(axes(reservoir_matrix, 1))
-        reservoir_matrix[lidx - 1, lidx] = second_cycle_weight
+        reservoir_matrix[lidx - 1, lidx] = T(second_cycle_weight)
     end

-    reservoir_matrix[1, dims[1]] = second_cycle_weight
-    reservoir_matrix[dims[1], 1] = cycle_weight
+    reservoir_matrix[1, dims[1]] = T(second_cycle_weight)
+    reservoir_matrix[dims[1], 1] = T(cycle_weight)
     return return_init_as(Val(return_sparse), reservoir_matrix)
 end

@@ -1347,7 +1347,7 @@ function selfloop_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
         return_sparse::Bool=false) where {T <: Number}
     throw_sparse_error(return_sparse)
     reservoir_matrix = simple_cycle(rng, T, dims...;
-        weight=cycle_weight, return_sparse=false)
+        weight=T(cycle_weight), return_sparse=false)
     reservoir_matrix += T(selfloop_weight) .* I(dims[1])
     return return_init_as(Val(return_sparse), reservoir_matrix)
 end
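
For context, a hedged usage sketch of what these conversions buy at the call sites above. The package import, RNG, and sizes are assumptions for illustration; the initializer signatures and keywords (scaling, radius) are taken from the diff:

using Random
# Assumes the initializers above are exported by the package that owns
# src/esn/esn_inits.jl; calls mirror the signatures shown in the diff.
rng = Random.default_rng()

win = scaled_rand(rng, Float32, 50, 3; scaling=0.1)   # Float64 keyword, Float32 request
res = rand_sparse(rng, Float32, 50, 50; radius=1.2)   # radius converted via T(radius)

@assert eltype(win) == Float32
@assert eltype(res) == Float32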
