diff --git a/Project.toml b/Project.toml index f42fa37d..77004ce0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ReservoirComputing" uuid = "7c2d2b1e-3dd4-11ea-355a-8f6a8116e294" authors = ["Francesco Martinuzzi"] -version = "0.10.11" +version = "0.10.12" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/docs/src/api/states.md b/docs/src/api/states.md index 49b31bfd..2be62fa0 100644 --- a/docs/src/api/states.md +++ b/docs/src/api/states.md @@ -16,6 +16,7 @@ NLAT1 NLAT2 NLAT3 + PartialSquare ``` ## Internals diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index fd965817..396e978f 100644 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -33,7 +33,7 @@ include("esn/esn_predict.jl") include("reca/reca.jl") include("reca/reca_input_encodings.jl") -export NLADefault, NLAT1, NLAT2, NLAT3 +export NLADefault, NLAT1, NLAT2, NLAT3, PartialSquare export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates export StandardRidge export scaled_rand, weighted_init, informed_init, minimal_init, chebyshev_mapping, diff --git a/src/states.jl b/src/states.jl index a38726b2..2f0cdbdf 100644 --- a/src/states.jl +++ b/src/states.jl @@ -654,3 +654,75 @@ function (::NLAT3)(x_old::AbstractVector) return x_new end + +@doc raw""" + PartialSquare(eta) + +Implement a partial squaring of the states as described in [^barbosa2021]. + +# Equations + +```math + \begin{equation} + g(r_i) = + \begin{cases} + r_i^2, & \text{if } i \leq \eta_r N, \\ + r_i, & \text{if } i > \eta_r N. + \end{cases} + \end{equation} +``` + +# Examples + +```jldoctest +julia> ps = PartialSquare(0.6) +PartialSquare(0.6) + + +julia> x_old = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] +10-element Vector{Int64}: + 0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + +julia> x_new = ps(x_old) +10-element Vector{Int64}: + 0 + 1 + 4 + 9 + 16 + 25 + 6 + 7 + 8 + 9 + + +[^barbosa2021]: Barbosa, Wendson AS, et al. + "Symmetry-aware reservoir computing." + Physical Review E 104.4 (2021): 045307. +""" +struct PartialSquare <: NonLinearAlgorithm + eta::Number +end + +function (ps::PartialSquare)(x_old::AbstractVector) + x_new = copy(x_old) + n_length = length(x_old) + threshold = floor(Int, ps.eta * n_length) + for idx in eachindex(x_old) + if idx <= threshold + x_new[idx] = x_old[idx]^2 + end + end + + return x_new +end diff --git a/test/test_states.jl b/test/test_states.jl index 1a191c4f..b885004b 100644 --- a/test/test_states.jl +++ b/test/test_states.jl @@ -8,7 +8,8 @@ test_types = [Float64, Float32, Float16] nlas = [(NLADefault(), test_array), (NLAT1(), [1, 2, 9, 4, 25, 6, 49, 8, 81]), (NLAT2(), [1, 2, 2, 4, 12, 6, 30, 8, 9]), - (NLAT3(), [1, 2, 8, 4, 24, 6, 48, 8, 9])] + (NLAT3(), [1, 2, 8, 4, 24, 6, 48, 8, 9]), + (PartialSquare(0.6), [1, 4, 9, 16, 25, 6, 7, 8, 9])] pes = [(StandardStates(), test_array), (PaddedStates(; padding=padding),