Skip to content

Commit a268310

Browse files
better generics
1 parent cabbb38 commit a268310

File tree

1 file changed

+15
-17
lines changed

1 file changed

+15
-17
lines changed

src/esn/esn_inits.jl

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -702,9 +702,8 @@ end
702702

703703
function delay_line!(reservoir_matrix::AbstractMatrix, weight::Number,
704704
shift::Int)
705-
for idx in first(axes(reservoir_matrix, 1)):(last(axes(reservoir_matrix, 1)) - shift)
706-
reservoir_matrix[idx + shift, idx] = weight
707-
end
705+
weights = fill(weight, size(reservoir_matrix, 1) - shift)
706+
delay_line!(reservoir_matrix, weights, shift)
708707
end
709708

710709
function delay_line!(reservoir_matrix::AbstractMatrix, weight::AbstractVector,
@@ -774,9 +773,8 @@ end
774773

775774
function backward_connection!(reservoir_matrix::AbstractMatrix, weight::Number,
776775
shift::Int)
777-
for idx in first(axes(reservoir_matrix, 1)):(last(axes(reservoir_matrix, 1)) - shift)
778-
reservoir_matrix[idx, idx + shift] = weight
779-
end
776+
weights = fill(weight, size(reservoir_matrix, 1) - shift)
777+
backward_connection!(reservoir_matrix, weights, shift)
780778
end
781779

782780
function backward_connection!(reservoir_matrix::AbstractMatrix, weight::AbstractVector,
@@ -842,17 +840,19 @@ function cycle_jumps(rng::AbstractRNG, ::Type{T}, dims::Integer...;
842840
res_size = first(dims)
843841
reservoir_matrix = DeviceAgnostic.zeros(rng, T, dims...)
844842
simple_cycle!(reservoir_matrix, cycle_weight)
843+
add_jumps!(reservoir_matrix, cycle_weight, jump_size)
844+
return return_init_as(Val(return_sparse), reservoir_matrix)
845+
end
845846

846-
for idx in 1:jump_size:(res_size - jump_size)
847-
tmp = (idx + jump_size) % res_size
847+
function add_jumps!(reservoir_matrix::AbstractMatrix, weight::Number, jump_size::Int)
848+
for idx in 1:jump_size:(size(reservoir_matrix, 1) - jump_size)
849+
tmp = (idx + jump_size) % size(reservoir_matrix, 1)
848850
if tmp == 0
849-
tmp = res_size
851+
tmp = size(reservoir_matrix, 1)
850852
end
851-
reservoir_matrix[idx, tmp] = T(cycle_weight)
852-
reservoir_matrix[tmp, idx] = T(cycle_weight)
853+
reservoir_matrix[idx, tmp] = weight
854+
reservoir_matrix[tmp, idx] = weight
853855
end
854-
855-
return return_init_as(Val(return_sparse), reservoir_matrix)
856856
end
857857

858858
"""
@@ -908,10 +908,8 @@ function simple_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
908908
end
909909

910910
function simple_cycle!(reservoir_matrix::AbstractMatrix, weight::Number)
911-
for idx in first(axes(reservoir_matrix, 1)):(last(axes(reservoir_matrix, 1)) - 1)
912-
reservoir_matrix[idx + 1, idx] = weight
913-
end
914-
reservoir_matrix[1, end] = weight
911+
weights = fill(weight, size(reservoir_matrix, 1))
912+
simple_cycle!(reservoir_matrix, weights)
915913
end
916914

917915
function simple_cycle!(reservoir_matrix::AbstractMatrix, weight::AbstractVector)

0 commit comments

Comments
 (0)