Skip to content

Commit 4b81176

Browse files
streamlining inits to WeightInitializers v1
1 parent a79dd1a commit 4b81176

File tree

4 files changed

+56
-55
lines changed

4 files changed

+56
-55
lines changed

Project.toml

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@ version = "0.10.4"
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
CellularAutomata = "878138dc-5b27-11ea-1a71-cb95d38d6b29"
99
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
10-
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1110
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1211
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
13-
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1412
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
1513
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1614
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -30,18 +28,16 @@ Aqua = "0.8"
3028
CellularAutomata = "0.0.2"
3129
DifferentialEquations = "7"
3230
Distances = "0.10"
33-
Distributions = "0.25.36"
3431
LIBSVM = "0.8"
3532
LinearAlgebra = "1.10"
3633
MLJLinearModels = "0.9.2, 0.10"
3734
NNlib = "0.8.4, 0.9"
38-
Optim = "1"
3935
PartialFunctions = "1.2"
4036
Random = "1.10"
4137
SafeTestsets = "0.1"
4238
Statistics = "1.10"
4339
Test = "1"
44-
WeightInitializers = "0.1.6"
40+
WeightInitializers = "1.0.4"
4541
julia = "1.10"
4642

4743
[extras]
@@ -52,4 +48,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
5248
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5349

5450
[targets]
55-
test = ["Aqua", "Test", "SafeTestsets", "Random", "DifferentialEquations", "MLJLinearModels", "LIBSVM"]
51+
test = ["Aqua", "Test", "SafeTestsets", "Random", "DifferentialEquations"]

src/ReservoirComputing.jl

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using Optim
1010
using PartialFunctions
1111
using Random
1212
using Statistics
13-
using WeightInitializers
13+
using WeightInitializers: WeightInitializers, DeviceAgnostic
1414

1515
export NLADefault, NLAT1, NLAT2, NLAT3
1616
export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates
@@ -106,30 +106,35 @@ end
106106

107107
__partial_apply(fn, inp) = fn$inp
108108

109-
#fallbacks for initializers
109+
#fallbacks for initializers #eventually to remove once migrated to WeightInitializers.jl
110110
for initializer in (:rand_sparse, :delay_line, :delay_line_backward, :cycle_jumps,
111111
:simple_cycle, :pseudo_svd,
112112
:scaled_rand, :weighted_init, :informed_init, :minimal_init)
113-
NType = ifelse(initializer === :rand_sparse, Real, Number)
114-
@eval function ($initializer)(dims::Integer...; kwargs...)
115-
return $initializer(WeightInitializers._default_rng(), Float32, dims...; kwargs...)
113+
@eval begin
114+
function ($initializer)(dims::Integer...; kwargs...)
115+
return $initializer(Utils.default_rng(), Float32, dims...; kwargs...)
116+
end
117+
function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...)
118+
return $initializer(rng, Float32, dims...; kwargs...)
119+
end
120+
function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T <: Number}
121+
return $initializer(Utils.default_rng(), T, dims...; kwargs...)
122+
end
123+
124+
# Partial application
125+
function ($initializer)(rng::AbstractRNG; kwargs...)
126+
return PartialFunction.Partial{Nothing}($initializer, rng, kwargs)
127+
end
128+
function ($initializer)(::Type{T}; kwargs...) where {T <: Number}
129+
return PartialFunction.Partial{T}($initializer, nothing, kwargs)
130+
end
131+
function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: Number}
132+
return PartialFunction.Partial{T}($initializer, rng, kwargs)
133+
end
134+
function ($initializer)(; kwargs...)
135+
return PartialFunction.Partial{Nothing}($initializer, nothing, kwargs)
136+
end
116137
end
117-
@eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...)
118-
return $initializer(rng, Float32, dims...; kwargs...)
119-
end
120-
@eval function ($initializer)(::Type{T},
121-
dims::Integer...; kwargs...) where {T <: $NType}
122-
return $initializer(WeightInitializers._default_rng(), T, dims...; kwargs...)
123-
end
124-
@eval function ($initializer)(rng::AbstractRNG; kwargs...)
125-
return __partial_apply($initializer, (rng, (; kwargs...)))
126-
end
127-
@eval function ($initializer)(rng::AbstractRNG,
128-
::Type{T}; kwargs...) where {T <: $NType}
129-
return __partial_apply($initializer, ((rng, T), (; kwargs...)))
130-
end
131-
@eval ($initializer)(; kwargs...) = __partial_apply(
132-
$initializer, (; kwargs...))
133138
end
134139

135140
#general

src/esn/esn_input_layers.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function scaled_rand(rng::AbstractRNG,
2626
dims::Integer...;
2727
scaling = T(0.1)) where {T <: Number}
2828
res_size, in_size = dims
29-
layer_matrix = T.(rand(rng, Uniform(-scaling, scaling), res_size, in_size))
29+
layer_matrix = T.(DeviceAgnostic.rand(rng, Uniform(-scaling, scaling), res_size, in_size))
3030
return layer_matrix
3131
end
3232

@@ -65,11 +65,11 @@ function weighted_init(rng::AbstractRNG,
6565
scaling = T(0.1)) where {T <: Number}
6666
approx_res_size, in_size = dims
6767
res_size = Int(floor(approx_res_size / in_size) * in_size)
68-
layer_matrix = zeros(T, res_size, in_size)
68+
layer_matrix = DeviceAgnostic.zeros(T, res_size, in_size)
6969
q = floor(Int, res_size / in_size)
7070

7171
for i in 1:in_size
72-
layer_matrix[((i - 1) * q + 1):((i) * q), i] = rand(rng,
72+
layer_matrix[((i - 1) * q + 1):((i) * q), i] = DeviceAgnostic.rand(rng,
7373
Uniform(-scaling, scaling),
7474
q)
7575
end
@@ -113,25 +113,25 @@ function informed_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
113113
throw(DimensionMismatch("in_size must be greater than model_in_size"))
114114
end
115115

116-
input_matrix = zeros(res_size, in_size)
117-
zero_connections = zeros(in_size)
116+
input_matrix = DeviceAgnostic.zeros(res_size, in_size)
117+
zero_connections = DeviceAgnostic.zeros(in_size)
118118
num_for_state = floor(Int, res_size * gamma)
119119
num_for_model = floor(Int, res_size * (1 - gamma))
120120

121121
for i in 1:num_for_state
122122
idxs = findall(Bool[zero_connections .== input_matrix[i, :]
123123
for i in 1:size(input_matrix, 1)])
124-
random_row_idx = idxs[rand(rng, 1:end)]
125-
random_clm_idx = range(1, state_size, step = 1)[rand(rng, 1:end)]
126-
input_matrix[random_row_idx, random_clm_idx] = rand(rng, Uniform(-scaling, scaling))
124+
random_row_idx = idxs[DeviceAgnostic.rand(rng, 1:end)]
125+
random_clm_idx = range(1, state_size, step = 1)[DeviceAgnostic.rand(rng, 1:end)]
126+
input_matrix[random_row_idx, random_clm_idx] = DeviceAgnostic.rand(rng, Uniform(-scaling, scaling))
127127
end
128128

129129
for i in 1:num_for_model
130130
idxs = findall(Bool[zero_connections .== input_matrix[i, :]
131131
for i in 1:size(input_matrix, 1)])
132-
random_row_idx = idxs[rand(rng, 1:end)]
133-
random_clm_idx = range(state_size + 1, in_size, step = 1)[rand(rng, 1:end)]
134-
input_matrix[random_row_idx, random_clm_idx] = rand(rng, Uniform(-scaling, scaling))
132+
random_row_idx = idxs[DeviceAgnostic.rand(rng, 1:end)]
133+
random_clm_idx = range(state_size + 1, in_size, step = 1)[DeviceAgnostic.rand(rng, 1:end)]
134+
input_matrix[random_row_idx, random_clm_idx] = DeviceAgnostic.rand(rng, Uniform(-scaling, scaling))
135135
end
136136

137137
return input_matrix
@@ -196,10 +196,10 @@ function _create_bernoulli(p::Number,
196196
weight::Number,
197197
rng::AbstractRNG,
198198
::Type{T}) where {T <: Number}
199-
input_matrix = zeros(T, res_size, in_size)
199+
input_matrix = DeviceAgnostic.zeros(T, res_size, in_size)
200200
for i in 1:res_size
201201
for j in 1:in_size
202-
rand(rng, Bernoulli(p)) ? (input_matrix[i, j] = weight) :
202+
DeviceAgnostic.rand(rng, Bernoulli(p)) ? (input_matrix[i, j] = weight) :
203203
(input_matrix[i, j] = -weight)
204204
end
205205
end
@@ -216,16 +216,16 @@ function _create_irrational(irrational::Irrational,
216216
setprecision(BigFloat, Int(ceil(log2(10) * (res_size * in_size + start + 1))))
217217
ir_string = string(BigFloat(irrational)) |> collect
218218
deleteat!(ir_string, findall(x -> x == '.', ir_string))
219-
ir_array = zeros(length(ir_string))
220-
input_matrix = zeros(T, res_size, in_size)
219+
ir_array = DeviceAgnostic.zeros(length(ir_string))
220+
input_matrix = DeviceAgnostic.zeros(T, res_size, in_size)
221221

222222
for i in 1:length(ir_string)
223223
ir_array[i] = parse(Int, ir_string[i])
224224
end
225225

226226
for i in 1:res_size
227227
for j in 1:in_size
228-
random_number = rand(rng, T)
228+
random_number = DeviceAgnostic.rand(rng, T)
229229
input_matrix[i, j] = random_number < 0.5 ? -weight : weight
230230
end
231231
end

src/esn/esn_reservoirs.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ function delay_line(rng::AbstractRNG,
6666
::Type{T},
6767
dims::Integer...;
6868
weight = T(0.1)) where {T <: Number}
69-
reservoir_matrix = zeros(T, dims...)
69+
reservoir_matrix = DeviceAgnostic.zeros(T, dims...)
7070
@assert length(dims) == 2&&dims[1] == dims[2] "The dimensions must define a square matrix (e.g., (100, 100))"
7171

7272
for i in 1:(dims[1] - 1)
@@ -107,7 +107,7 @@ function delay_line_backward(rng::AbstractRNG,
107107
weight = T(0.1),
108108
fb_weight = T(0.2)) where {T <: Number}
109109
res_size = first(dims)
110-
reservoir_matrix = zeros(T, dims...)
110+
reservoir_matrix = DeviceAgnostic.zeros(T, dims...)
111111

112112
for i in 1:(res_size - 1)
113113
reservoir_matrix[i + 1, i] = weight
@@ -148,7 +148,7 @@ function cycle_jumps(rng::AbstractRNG,
148148
jump_weight::Number = T(0.1),
149149
jump_size::Int = 3) where {T <: Number}
150150
res_size = first(dims)
151-
reservoir_matrix = zeros(T, dims...)
151+
reservoir_matrix = DeviceAgnostic.zeros(T, dims...)
152152

153153
for i in 1:(res_size - 1)
154154
reservoir_matrix[i + 1, i] = cycle_weight
@@ -194,7 +194,7 @@ function simple_cycle(rng::AbstractRNG,
194194
::Type{T},
195195
dims::Integer...;
196196
weight = T(0.1)) where {T <: Number}
197-
reservoir_matrix = zeros(T, dims...)
197+
reservoir_matrix = DeviceAgnostic.zeros(T, dims...)
198198

199199
for i in 1:(dims[1] - 1)
200200
reservoir_matrix[i + 1, i] = weight
@@ -246,9 +246,9 @@ function pseudo_svd(rng::AbstractRNG,
246246

247247
while tmp_sparsity <= sparsity
248248
reservoir_matrix *= create_qmatrix(dims[1],
249-
rand(1:dims[1]),
250-
rand(1:dims[1]),
251-
rand(T) * T(2) - T(1),
249+
DeviceAgnostic.rand(1:dims[1]),
250+
DeviceAgnostic.rand(1:dims[1]),
251+
DeviceAgnostic.rand(T) * T(2) - T(1),
252252
T)
253253
tmp_sparsity = get_sparsity(reservoir_matrix, dims[1])
254254
end
@@ -258,17 +258,17 @@ end
258258

259259
function create_diag(dim::Number, max_value::Number, ::Type{T};
260260
sorted::Bool = true, reverse_sort::Bool = false) where {T <: Number}
261-
diagonal_matrix = zeros(T, dim, dim)
261+
diagonal_matrix = DeviceAgnostic.zeros(T, dim, dim)
262262
if sorted == true
263263
if reverse_sort == true
264-
diagonal_values = sort(rand(T, dim) .* max_value, rev = true)
264+
diagonal_values = sort(DeviceAgnostic.rand(T, dim) .* max_value, rev = true)
265265
diagonal_values[1] = max_value
266266
else
267-
diagonal_values = sort(rand(T, dim) .* max_value)
267+
diagonal_values = sort(DeviceAgnostic.rand(T, dim) .* max_value)
268268
diagonal_values[end] = max_value
269269
end
270270
else
271-
diagonal_values = rand(T, dim) .* max_value
271+
diagonal_values = DeviceAgnostic.rand(T, dim) .* max_value
272272
end
273273

274274
for i in 1:dim
@@ -283,7 +283,7 @@ function create_qmatrix(dim::Number,
283283
coord_j::Number,
284284
theta::Number,
285285
::Type{T}) where {T <: Number}
286-
qmatrix = zeros(T, dim, dim)
286+
qmatrix = DeviceAgnostic.zeros(T, dim, dim)
287287

288288
for i in 1:dim
289289
qmatrix[i, i] = 1.0

0 commit comments

Comments (0)