Skip to content

Commit e7875d2

Browse files
Merge pull request #289 from SciML/fm/blockdiagonal
feat: add block_diagonal initializer
2 parents db51b6e + 588efd3 commit e7875d2

File tree

5 files changed

+148
-41
lines changed

5 files changed

+148
-41
lines changed

docs/src/api/inits.md

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,46 +3,47 @@
33
## Input layers
44

55
```@docs
6-
scaled_rand
7-
weighted_init
8-
minimal_init
9-
weighted_minimal
106
chebyshev_mapping
7+
informed_init
118
logistic_mapping
9+
minimal_init
1210
modified_lm
13-
informed_init
11+
scaled_rand
12+
weighted_init
13+
weighted_minimal
1414
```
1515

1616
## Reservoirs
1717

1818
```@docs
19-
rand_sparse
20-
pseudo_svd
19+
block_diagonal
2120
chaotic_init
22-
low_connectivity
21+
cycle_jumps
2322
delay_line
2423
delay_line_backward
25-
simple_cycle
26-
cycle_jumps
2724
double_cycle
28-
true_double_cycle
25+
forward_connection
26+
low_connectivity
27+
pseudo_svd
28+
rand_sparse
2929
selfloop_cycle
30-
selfloop_feedback_cycle
3130
selfloop_delayline_backward
31+
selfloop_feedback_cycle
3232
selfloop_forward_connection
33-
forward_connection
33+
simple_cycle
34+
true_double_cycle
3435
```
3536

3637
## Building functions
3738

3839
```@docs
39-
scale_radius!
40-
delay_line!
40+
add_jumps!
4141
backward_connection!
42-
simple_cycle!
42+
delay_line!
4343
reverse_simple_cycle!
44+
scale_radius!
4445
self_loop!
45-
add_jumps!
46+
simple_cycle!
4647
```
4748

4849
## References

docs/src/refs.bib

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,4 +326,18 @@ @article{Herteux2020
326326
author = {Herteux, Joschka and R\"{a}th, Christoph},
327327
year = {2020},
328328
month = dec
329+
}
330+
331+
@article{Ma2023,
332+
title = {Efficient forecasting of chaotic systems with block-diagonal and binary reservoir computing},
333+
volume = {33},
334+
ISSN = {1089-7682},
335+
url = {http://dx.doi.org/10.1063/5.0151290},
336+
DOI = {10.1063/5.0151290},
337+
number = {6},
338+
journal = {Chaos: An Interdisciplinary Journal of Nonlinear Science},
339+
publisher = {AIP Publishing},
340+
author = {Ma, Haochun and Prosperino, Davide and Haluszczynski, Alexander and R\"{a}th, Christoph},
341+
year = {2023},
342+
month = jun
329343
}

src/ReservoirComputing.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,14 @@ include("reca/reca_input_encodings.jl")
3737
export NLADefault, NLAT1, NLAT2, NLAT3, PartialSquare, ExtendedSquare
3838
export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates
3939
export StandardRidge
40-
export scaled_rand, weighted_init, informed_init, minimal_init, chebyshev_mapping,
41-
logistic_mapping, modified_lm, weighted_minimal
42-
export rand_sparse, delay_line, delay_line_backward, cycle_jumps,
43-
simple_cycle, pseudo_svd, chaotic_init, low_connectivity, double_cycle,
44-
selfloop_cycle, selfloop_feedback_cycle, selfloop_delayline_backward,
45-
selfloop_forward_connection, forward_connection, true_double_cycle
46-
export scale_radius!, delay_line!, backward_connection!, simple_cycle!,
47-
reverse_simple_cycle!,
48-
add_jumps!, self_loop!
40+
export chebyshev_mapping, informed_init, logistic_mapping, minimal_init,
41+
modified_lm, scaled_rand, weighted_init, weighted_minimal
42+
export block_diagonal, chaotic_init, cycle_jumps, delay_line, delay_line_backward,
43+
double_cycle, forward_connection, low_connectivity, pseudo_svd, rand_sparse,
44+
selfloop_cycle, selfloop_delayline_backward, selfloop_feedback_cycle,
45+
selfloop_forward_connection, simple_cycle, true_double_cycle
46+
export add_jumps!, backward_connection!, delay_line!, reverse_simple_cycle!,
47+
scale_radius!, self_loop!, simple_cycle!
4948
export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal
5049
export train
5150
export ESN, HybridESN, KnowledgeModel, DeepESN

src/esn/esn_inits.jl

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1881,14 +1881,106 @@ function forward_connection(rng::AbstractRNG, ::Type{T}, dims::Integer...;
18811881
return return_init_as(Val(return_sparse), reservoir_matrix)
18821882
end
18831883

1884+
@doc raw"""
1885+
block_diagonal([rng], [T], dims...;
1886+
weight=1, block_size=1,
1887+
return_sparse=false)
1888+
1889+
Creates a block‐diagonal matrix consisting of square blocks of size
1890+
`block_size` along the main diagonal [Ma2023](@cite).
1891+
Each block may be filled with
1892+
- a single scalar
1893+
- a vector of per‐block weights (length = number of blocks)
1894+
1895+
# Equations
1896+
1897+
```math
1898+
W_{i,j} =
1899+
\begin{cases}
1900+
w_b, & \text{if }\left\lfloor\frac{i-1}{s}\right\rfloor = \left\lfloor\frac{j-1}{s}\right\rfloor = b,\;
1901+
s = \text{block\_size},\; b=0,\dots,nb-1, \\
1902+
0, & \text{otherwise,}
1903+
\end{cases}
1904+
```
1905+
1906+
# Arguments
1907+
1908+
- `rng`: Random number generator. Default is `Utils.default_rng()`.
1909+
- `T`: Element type of the matrix. Default is `Float32`.
1910+
- `dims`: Dimensions of the output matrix (must be two-dimensional).
1911+
1912+
# Keyword arguments
1913+
1914+
- `weight`:
1915+
- scalar: every block is filled with that value
1916+
- vector: length = number of blocks, one constant per block
1917+
Default is `1.0`.
1918+
- `block_size`: Size\(s\) of each square block on the diagonal. Default is `1.0`.
1919+
- `return_sparse`: If `true`, returns the matrix as sparse.
1920+
SparseArrays.jl must be lodead.
1921+
Default is `false`.
1922+
1923+
# Examples
1924+
1925+
```jldoctest
1926+
# 4×4 with two 2×2 blocks of 1.0
1927+
julia> W1 = block_diagonal(4, 4; block_size=2)
1928+
4×4 Matrix{Float32}:
1929+
1.0 1.0 0.0 0.0
1930+
1.0 1.0 0.0 0.0
1931+
0.0 0.0 1.0 1.0
1932+
0.0 0.0 1.0 1.0
1933+
1934+
# per-block weights [0.5, 2.0]
1935+
julia> W2 = block_diagonal(4, 4; block_size=2, weight=[0.5, 2.0])
1936+
4×4 Matrix{Float32}:
1937+
0.5 0.5 0.0 0.0
1938+
0.5 0.5 0.0 0.0
1939+
0.0 0.0 2.0 2.0
1940+
0.0 0.0 2.0 2.0
1941+
```
1942+
"""
1943+
function block_diagonal(rng::AbstractRNG, ::Type{T}, dims::Integer...;
1944+
weight::Union{Number, AbstractVector} = T(1),
1945+
block_size::Integer = 1,
1946+
return_sparse::Bool = false) where {T <: Number}
1947+
throw_sparse_error(return_sparse)
1948+
check_res_size(dims...)
1949+
n_rows, n_cols = dims
1950+
total = min(n_rows, n_cols)
1951+
num_blocks = fld(total, block_size)
1952+
remainder = total - num_blocks * block_size
1953+
if remainder != 0
1954+
@warn "\n
1955+
With block_size=$block_size on a $n_rows×$n_cols matrix,
1956+
only $num_blocks block(s) of size $block_size fit,
1957+
leaving $remainder row(s)/column(s) unused.
1958+
\n"
1959+
end
1960+
weights = isa(weight, AbstractVector) ? T.(weight) : fill(T(weight), num_blocks)
1961+
@assert length(weights)==num_blocks "
1962+
weight vector must have length = number of blocks
1963+
"
1964+
reservoir_matrix = DeviceAgnostic.zeros(rng, T, n_rows, n_cols)
1965+
for block in 1:num_blocks
1966+
row_start = (block - 1) * block_size + 1
1967+
row_end = row_start + block_size - 1
1968+
col_start = (block - 1) * block_size + 1
1969+
col_end = col_start + block_size - 1
1970+
@inbounds reservoir_matrix[row_start:row_end, col_start:col_end] .= weights[block]
1971+
end
1972+
1973+
return return_init_as(Val(return_sparse), reservoir_matrix)
1974+
end
1975+
18841976
### fallbacks
18851977
#fallbacks for initializers #eventually to remove once migrated to WeightInitializers.jl
18861978
for initializer in (:rand_sparse, :delay_line, :delay_line_backward, :cycle_jumps,
18871979
:simple_cycle, :pseudo_svd, :chaotic_init, :scaled_rand, :weighted_init,
18881980
:weighted_minimal, :informed_init, :minimal_init, :chebyshev_mapping,
18891981
:logistic_mapping, :modified_lm, :low_connectivity, :double_cycle, :selfloop_cycle,
18901982
:selfloop_feedback_cycle, :selfloop_delayline_backward, :selfloop_forward_connection,
1891-
:forward_connection, :true_double_cycle)
1983+
:forward_connection, :true_double_cycle, :block_diagonal)
18921984
@eval begin
18931985
function ($initializer)(dims::Integer...; kwargs...)
18941986
return $initializer(Utils.default_rng(), Float32, dims...; kwargs...)

test/esn/test_inits.jl

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,31 +19,32 @@ end
1919

2020
ft = [Float16, Float32, Float64]
2121
reservoir_inits = [
22-
rand_sparse,
22+
block_diagonal,
23+
chaotic_init,
24+
cycle_jumps,
2325
delay_line,
2426
delay_line_backward,
25-
cycle_jumps,
26-
simple_cycle,
27-
pseudo_svd,
28-
chaotic_init,
29-
low_connectivity,
3027
double_cycle,
28+
forward_connection,
29+
low_connectivity,
30+
pseudo_svd,
31+
rand_sparse,
3132
selfloop_cycle,
32-
selfloop_feedback_cycle,
3333
selfloop_delayline_backward,
34+
selfloop_feedback_cycle,
3435
selfloop_forward_connection,
35-
forward_connection,
36+
simple_cycle,
3637
true_double_cycle
3738
]
3839
input_inits = [
39-
scaled_rand,
40-
weighted_init,
41-
weighted_minimal,
42-
minimal_init,
43-
minimal_init(; sampling_type = :irrational_sample!),
4440
chebyshev_mapping,
4541
logistic_mapping,
46-
modified_lm(; factor = 4)
42+
minimal_init,
43+
minimal_init(; sampling_type = :irrational_sample!),
44+
modified_lm(; factor = 4),
45+
scaled_rand,
46+
weighted_init,
47+
weighted_minimal
4748
]
4849

4950
@testset "Reservoir Initializers" begin

0 commit comments

Comments
 (0)