Skip to content

Commit 2de99aa

Browse files
feat: add scaling per single column in input layer
1 parent 837ac47 commit 2de99aa

File tree

1 file changed

+48
-3
lines changed

1 file changed

+48
-3
lines changed

src/esn/esn_inits.jl

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,15 @@ a range defined by `scaling`.
1717
# Keyword arguments
1818
1919
- `scaling`: A scaling factor to define the range of the uniform distribution.
20-
The matrix elements will be randomly chosen from the
21-
range `[-scaling, scaling]`. Defaults to `0.1`.
20+
The factor can be passed in three different ways:
21+
22+
+ A single number. In this case, the matrix elements will be randomly
23+
chosen from the range `[-scaling, scaling]`. Default option, with
24+
a the scaling value set to `0.1`.
25+
+ A tuple `(lower, upper)`. The values define the range of the distribution.
26+
+ A vector. In this case, the columns will be scaled individually by the
27+
entries of the vector. The entries can be numbers or tuples, which will mirror
28+
the behavior described above.
2229
2330
# Examples
2431
@@ -33,10 +40,38 @@ julia> res_input = scaled_rand(8, 3)
3340
0.0944272 0.0679244 0.0148647
3441
-0.0799005 -0.0891089 -0.0444782
3542
-0.0970182 0.0934286 0.03553
43+
44+
julia> tt = scaled_rand(5, 3, scaling = (0.1, 0.15))
45+
5×3 Matrix{Float32}:
46+
0.13631 0.110929 0.116177
47+
0.116299 0.136038 0.119713
48+
0.11535 0.144712 0.110029
49+
0.127453 0.12657 0.147656
50+
0.139446 0.117656 0.104712
51+
```
52+
53+
Example with vector:
54+
55+
```jldoctest
56+
julia> tt = scaled_rand(5, 3, scaling = [0.1, 0.2, 0.3])
57+
5×3 Matrix{Float32}:
58+
0.0452399 -0.112565 -0.105874
59+
-0.0348047 0.0883044 -0.0634468
60+
-0.0386004 0.157698 -0.179648
61+
0.00981022 0.012559 0.271875
62+
0.0577838 -0.0587553 -0.243451
63+
64+
julia> tt = scaled_rand(5, 3, scaling = [(0.1, 0.2), (-0.2, -0.1), (0.3, 0.5)])
65+
5×3 Matrix{Float32}:
66+
0.17262 -0.178141 0.364709
67+
0.132598 -0.127924 0.378851
68+
0.1307 -0.110575 0.340117
69+
0.154905 -0.14686 0.490625
70+
0.178892 -0.164689 0.31885
3671
```
3772
"""
3873
function scaled_rand(rng::AbstractRNG, ::Type{T}, dims::Integer...;
39-
scaling::Union{Number, Tuple} = T(0.1)) where {T <: Number}
74+
scaling::Union{Number, Tuple, Vector} = T(0.1)) where {T <: Number}
4075
res_size, in_size = dims
4176
layer_matrix = DeviceAgnostic.rand(rng, T, res_size, in_size)
4277
apply_scale!(layer_matrix, scaling, T)
@@ -57,6 +92,16 @@ function apply_scale!(input_matrix,
5792
return input_matrix
5893
end
5994

95+
function apply_scale!(input_matrix,
96+
scaling::AbstractVector, ::Type{T}) where {T <: Number}
97+
ncols = size(input_matrix, 2)
98+
@assert length(scaling)==ncols "need one scaling per column"
99+
for (idx, col) in enumerate(eachcol(input_matrix))
100+
apply_scale!(col, scaling[idx], T)
101+
end
102+
return input_matrix
103+
end
104+
60105
"""
61106
weighted_init([rng], [T], dims...;
62107
scaling=0.1, return_sparse=false)

0 commit comments

Comments
 (0)