Skip to content

Commit 4418a90

Browse files
Merge pull request #237 from SciML/fm/types
Changing defaults to Float32
2 parents f0849f3 + 9eb5c8c commit 4418a90

File tree

4 files changed

+19
-16
lines changed

4 files changed

+19
-16
lines changed

src/esn/deepesn.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ temporal features.
3737
- `input_layer`: A function or an array of functions to initialize the input
3838
matrices for each layer. Default is `scaled_rand` for each layer.
3939
- `bias`: A function or an array of functions to initialize the bias vectors
40-
for each layer. Default is `zeros64` for each layer.
40+
for each layer. Default is `zeros32` for each layer.
4141
- `reservoir`: A function or an array of functions to initialize the reservoir
4242
matrices for each layer. Default is `rand_sparse` for each layer.
4343
- `reservoir_driver`: The driving system for the reservoir.
@@ -50,8 +50,6 @@ temporal features.
5050
Default is 0.
5151
- `rng`: Random number generator used for initializing weights. Default is the package's
5252
default random number generator.
53-
- `T`: The data type for the matrices (e.g., `Float64`). Influences computational
54-
efficiency and precision.
5553
- `matrix_type`: The type of matrix used for storing the training data.
5654
Default is inferred from `train_data`.
5755
@@ -74,21 +72,21 @@ function DeepESN(train_data,
7472
res_size::Int;
7573
depth::Int=2,
7674
input_layer=fill(scaled_rand, depth),
77-
bias=fill(zeros64, depth),
75+
bias=fill(zeros32, depth),
7876
reservoir=fill(rand_sparse, depth),
7977
reservoir_driver=RNN(),
8078
nla_type=NLADefault(),
8179
states_type=StandardStates(),
8280
washout::Int=0,
8381
rng=Utils.default_rng(),
84-
T=Float64,
8582
matrix_type=typeof(train_data))
8683
if states_type isa AbstractPaddedStates
8784
in_size = size(train_data, 1) + 1
8885
train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))),
8986
train_data)
9087
end
9188

89+
T = eltype(train_data)
9290
reservoir_matrix = [reservoir[i](rng, T, res_size, res_size) for i in 1:depth]
9391
input_matrix = [i == 1 ? input_layer[i](rng, T, res_size, in_size) :
9492
input_layer[i](rng, T, res_size, res_size) for i in 1:depth]

src/esn/esn.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,20 @@ function ESN(train_data,
4949
res_size::Int;
5050
input_layer=scaled_rand,
5151
reservoir=rand_sparse,
52-
bias=zeros64,
52+
bias=zeros32,
5353
reservoir_driver=RNN(),
5454
nla_type=NLADefault(),
5555
states_type=StandardStates(),
5656
washout=0,
5757
rng=Utils.default_rng(),
58-
T=Float32,
5958
matrix_type=typeof(train_data))
6059
if states_type isa AbstractPaddedStates
6160
in_size = size(train_data, 1) + 1
6261
train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))),
6362
train_data)
6463
end
6564

65+
T = eltype(train_data)
6666
reservoir_matrix = reservoir(rng, T, res_size, res_size)
6767
input_matrix = input_layer(rng, T, res_size, in_size)
6868
bias_vector = bias(rng, res_size)

src/esn/hybridesn.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ function HybridESN(model,
121121
res_size::Int;
122122
input_layer=scaled_rand,
123123
reservoir=rand_sparse,
124-
bias=zeros64,
124+
bias=zeros32,
125125
reservoir_driver=RNN(),
126126
nla_type=NLADefault(),
127127
states_type=StandardStates(),

test/esn/deepesn.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,19 @@ const train_len = 400
77
const predict_len = 100
88
const input_data = reduce(hcat, data[1:(train_len - 1)])
99
const target_data = reduce(hcat, data[2:train_len])
10-
const test = reduce(hcat, data[(train_len + 1):(train_len + predict_len)])
10+
const test_data = reduce(hcat, data[(train_len + 1):(train_len + predict_len)])
1111
const reg = 10e-6
12-
#test_types = [Float64, Float32, Float16]
1312

14-
Random.seed!(77)
15-
res = rand_sparse(; radius=1.2, sparsity=0.1)
16-
esn = DeepESN(input_data, 1, res_size)
13+
test_types = [Float64, Float32, Float16]
14+
zeros_types = [zeros64, zeros32, zeros16]
1715

18-
output_layer = train(esn, target_data)
19-
output = esn(Generative(length(test)), output_layer)
20-
@test mean(abs.(test .- output)) ./ mean(abs.(test)) < 0.22
16+
for (tidx, t) in enumerate(test_types)
17+
Random.seed!(77)
18+
res = rand_sparse(; radius=1.2, sparsity=0.1)
19+
esn = DeepESN(t.(input_data), 1, res_size;
20+
bias=fill(zeros_types[tidx], 2))
21+
22+
output_layer = train(esn, t.(target_data))
23+
output = esn(Generative(length(test_data)), output_layer)
24+
@test eltype(output) == t
25+
end

0 commit comments

Comments
 (0)