Skip to content

Commit 44fe813

Browse files
Merge pull request #250 from SciML/fm/mf
Minor refactor
2 parents 702f681 + 69ffe75 commit 44fe813

File tree

8 files changed

+25
-41
lines changed

8 files changed

+25
-41
lines changed

src/esn/deepesn.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ function DeepESN(train_data::AbstractArray, in_size::Int, res_size::Int; depth::
8484
input_matrix, bias_vector)
8585
train_data = train_data[:, (washout + 1):end]
8686

87-
DeepESN(res_size, train_data, nla_type, input_matrix,
87+
return DeepESN(res_size, train_data, nla_type, input_matrix,
8888
inner_res_driver, reservoir_matrix, bias_vector, states_type, washout,
8989
states)
9090
end

src/esn/esn.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,14 @@ function ESN(train_data::AbstractArray, in_size::Int, res_size::Int;
7676
input_matrix, bias_vector)
7777
train_data = train_data[:, (washout + 1):end]
7878

79-
ESN(res_size, train_data, nla_type, input_matrix,
79+
return ESN(res_size, train_data, nla_type, input_matrix,
8080
inner_res_driver, reservoir_matrix, bias_vector, states_type, washout,
8181
states)
8282
end
8383

8484
function (esn::AbstractEchoStateNetwork)(prediction::AbstractPrediction,
8585
output_layer::AbstractOutputLayer; last_state=esn.states[:, [end]],
8686
kwargs...)
87-
pred_len = prediction.prediction_len
8887
return obtain_esn_prediction(esn, prediction, last_state, output_layer;
8988
kwargs...)
9089
end

src/esn/esn_inits.jl

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -237,30 +237,15 @@ julia> res_input = minimal_init(8, 3; p=0.8)# higher p -> more positive signs
237237
```
238238
"""
239239
function minimal_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
240-
sampling_type::Symbol=:bernoulli, weight::Number=T(0.1), irrational::Real=pi,
241-
start::Int=1, p::Number=T(0.5)) where {T <: Number}
240+
sampling_type::Symbol=:bernoulli, kwargs...) where {T <: Number}
242241
res_size, in_size = dims
243-
if sampling_type == :bernoulli
244-
layer_matrix = _create_bernoulli(p, res_size, in_size, weight, rng, T)
245-
elseif sampling_type == :irrational
246-
layer_matrix = _create_irrational(irrational,
247-
start,
248-
res_size,
249-
in_size,
250-
weight,
251-
rng,
252-
T)
253-
else
254-
error("""\n
255-
Sampling type not allowed.
256-
Please use one of :bernoulli or :irrational\n
257-
""")
258-
end
242+
f_sample = getfield(@__MODULE__, sampling_type)
243+
layer_matrix = f_sample(rng, T, res_size, in_size; kwargs...)
259244
return layer_matrix
260245
end
261246

262-
function _create_bernoulli(p::Number, res_size::Int, in_size::Int, weight::Number,
263-
rng::AbstractRNG, ::Type{T}) where {T <: Number}
247+
function bernoulli(rng::AbstractRNG, ::Type{T}, res_size::Int, in_size::Int;
248+
weight::Number=T(0.1), p::Number=T(0.5)) where {T <: Number}
264249
input_matrix = DeviceAgnostic.zeros(rng, T, res_size, in_size)
265250
for i in 1:res_size
266251
for j in 1:in_size
@@ -274,9 +259,9 @@ function _create_bernoulli(p::Number, res_size::Int, in_size::Int, weight::Numbe
274259
return input_matrix
275260
end
276261

277-
function _create_irrational(irrational::Irrational, start::Int, res_size::Int,
278-
in_size::Int, weight::Number, rng::AbstractRNG,
279-
::Type{T}) where {T <: Number}
262+
function irrational(rng::AbstractRNG, ::Type{T}, res_size::Int, in_size::Int;
263+
irrational::Irrational=pi, start::Int=1,
264+
weight::Number=T(0.1)) where {T <: Number}
280265
setprecision(BigFloat, Int(ceil(log2(10) * (res_size * in_size + start + 1))))
281266
ir_string = string(BigFloat(irrational)) |> collect
282267
deleteat!(ir_string, findall(x -> x == '.', ir_string))

src/esn/esn_predict.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ function obtain_esn_prediction(esn,
2525
states[:, i] = x
2626
end
2727

28-
save_states ? (output, states) : output
28+
return save_states ? (output, states) : output
2929
end
3030

3131
function obtain_esn_prediction(esn,
@@ -55,7 +55,7 @@ function obtain_esn_prediction(esn,
5555
states[:, i] = x
5656
end
5757

58-
save_states ? (output, states) : output
58+
return save_states ? (output, states) : output
5959
end
6060

6161
#prediction dispatch on esn
@@ -98,11 +98,11 @@ function allocate_outpad(hesn::HybridESN, states_type, out)
9898
end
9999

100100
function allocate_singlepadding(::AbstractPaddedStates, out)
101-
adapt(typeof(out), zeros(size(out, 1) + 1))
101+
return adapt(typeof(out), zeros(size(out, 1) + 1))
102102
end
103103
function allocate_singlepadding(::StandardStates, out)
104-
adapt(typeof(out), zeros(size(out, 1)))
104+
return adapt(typeof(out), zeros(size(out, 1)))
105105
end
106106
function allocate_singlepadding(::ExtendedStates, out)
107-
adapt(typeof(out), zeros(size(out, 1)))
107+
return adapt(typeof(out), zeros(size(out, 1)))
108108
end

src/esn/esn_reservoir_drivers.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,19 +101,19 @@ echo state networks (`ESN`).
101101
Defaults to 1.0.
102102
"""
103103
function RNN(; activation_function=fast_act(tanh), leaky_coefficient=1.0)
104-
RNN(activation_function, leaky_coefficient)
104+
return RNN(activation_function, leaky_coefficient)
105105
end
106106

107107
function reservoir_driver_params(rnn::RNN, args...)
108-
rnn
108+
return rnn
109109
end
110110

111111
function next_state!(out, rnn::RNN, x, y, W, W_in, b, tmp_array)
112112
mul!(tmp_array[1], W, x)
113113
mul!(tmp_array[2], W_in, y)
114114
@. tmp_array[1] = rnn.activation_function(tmp_array[1] + tmp_array[2] + b) *
115115
rnn.leaky_coefficient
116-
@. out = (1 - rnn.leaky_coefficient) * x + tmp_array[1]
116+
return @. out = (1 - rnn.leaky_coefficient) * x + tmp_array[1]
117117
end
118118

119119
function next_state!(out, rnn::RNN, x, y, W::Vector, W_in, b, tmp_array)
@@ -353,7 +353,7 @@ function obtain_gru_state!(out, variant::FullyGated, gru, x, y, W, W_in, b, tmp_
353353
mul!(tmp_array[7], W_in, y)
354354
mul!(tmp_array[8], W, tmp_array[6] .* x)
355355
@. tmp_array[9] = gru.activation_function[3](tmp_array[7] + tmp_array[8] + b)
356-
@. out = (1 - tmp_array[3]) * x + tmp_array[3] * tmp_array[9]
356+
return @. out = (1 - tmp_array[3]) * x + tmp_array[3] * tmp_array[9]
357357
end
358358

359359
#minimal
@@ -366,5 +366,5 @@ function obtain_gru_state!(out, variant::Minimal, gru, x, y, W, W_in, b, tmp_arr
366366
mul!(tmp_array[5], W, tmp_array[3] .* x)
367367
@. tmp_array[6] = gru.activation_function[2](tmp_array[4] + tmp_array[5] + b)
368368

369-
@. out = (1 - tmp_array[3]) * x + tmp_array[3] * tmp_array[6]
369+
return @. out = (1 - tmp_array[3]) * x + tmp_array[3] * tmp_array[6]
370370
end

src/esn/hybridesn.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ function HybridESN(model::KnowledgeModel, train_data::AbstractArray,
119119
input_matrix, bias_vector)
120120
train_data = train_data[:, (washout + 1):end]
121121

122-
HybridESN(res_size, train_data, model, nla_type, input_matrix,
122+
return HybridESN(res_size, train_data, model, nla_type, input_matrix,
123123
inner_res_driver, reservoir_matrix, bias_vector, states_type, washout,
124124
states)
125125
end

src/predict.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ on the learned relationships in the model.
6363
"""
6464
function Predictive(prediction_data::AbstractArray)
6565
prediction_len = size(prediction_data, 2)
66-
Predictive(prediction_data, prediction_len)
66+
return Predictive(prediction_data, prediction_len)
6767
end
6868

6969
function obtain_prediction(rc::AbstractReservoirComputer, prediction::Generative,

src/states.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ abstract type AbstractPaddedStates <: AbstractStates end
33
abstract type NonLinearAlgorithm end
44

55
function pad_state!(states_type::AbstractPaddedStates, x_pad, x)
6-
x_pad = vcat(fill(states_type.padding, (1, size(x, 2))), x)
6+
x_pad[1, :] .= states_type.padding
7+
x_pad[2:end, :] .= x
78
return x_pad
89
end
910

1011
function pad_state!(states_type, x_pad, x)
11-
x_pad = x
12-
return x_pad
12+
return x
1313
end
1414

1515
#states types

0 commit comments

Comments
 (0)