1+ abstract type AbstractEchoStateNetworkCell <: AbstractReservoirRecurrentCell end
2+
13@doc raw """
24 ESNCell(in_dims => out_dims, [activation];
35 use_bias=false, init_bias=rand32,
@@ -63,7 +65,7 @@ Created by `initialstates(rng, esn)`:
6365
6466 - `rng`: a replicated RNG used to sample initial hidden states when needed.
6567"""
66- @concrete struct ESNCell <: AbstractReservoirRecurrentCell
68+ @concrete struct ESNCell <: AbstractEchoStateNetworkCell
6769 activation:: Any
6870 in_dims <: IntegerType
6971 out_dims <: IntegerType
@@ -76,15 +78,15 @@ Created by `initialstates(rng, esn)`:
7678 use_bias <: StaticBool
7779end
7880
79- function ESNCell (
80- (in_dims, out_dims) :: Pair{<:IntegerType, <:IntegerType} , activation = tanh;
81- use_bias :: BoolType = False (), init_bias = zeros32, init_reservoir = rand_sparse ,
82- init_input = scaled_rand, init_state = randn32, leak_coefficient = 1.0 )
81+ function ESNCell ((in_dims, out_dims) :: Pair{<:IntegerType, <:IntegerType} ,
82+ activation = tanh; use_bias :: BoolType = False (), init_bias = zeros32,
83+ init_reservoir = rand_sparse, init_input = scaled_rand ,
84+ init_state = randn32, leak_coefficient = 1.0 )
8385 return ESNCell (activation, in_dims, out_dims, init_bias, init_reservoir,
8486 init_input, init_state, leak_coefficient, use_bias)
8587end
8688
87- function initialparameters (rng:: AbstractRNG , esn:: ESNCell )
89+ function initialparameters (rng:: AbstractRNG , esn:: AbstractEchoStateNetworkCell )
8890 ps = (input_matrix = esn. init_input (rng, esn. out_dims, esn. in_dims),
8991 reservoir_matrix = esn. init_reservoir (rng, esn. out_dims, esn. out_dims))
9092 if has_bias (esn)
@@ -93,11 +95,11 @@ function initialparameters(rng::AbstractRNG, esn::ESNCell)
9395 return ps
9496end
9597
96- function initialstates (rng:: AbstractRNG , esn:: ESNCell )
98+ function initialstates (rng:: AbstractRNG , esn:: AbstractEchoStateNetworkCell )
9799 return (rng = sample_replicate (rng),)
98100end
99101
100- function (esn:: ESNCell )(inp:: AbstractArray , ps, st:: NamedTuple )
102+ function (esn:: AbstractEchoStateNetworkCell )(inp:: AbstractArray , ps, st:: NamedTuple )
101103 rng = replicate (st. rng)
102104 hidden_state = init_hidden_state (rng, esn, inp)
103105 return esn ((inp, (hidden_state,)), ps, merge (st, (; rng)))
0 commit comments