66
77import brainpy .math as bm
88from brainpy ._src .initialize import Normal , ZeroInit , Initializer , parameter , variable
9- from brainpy . check import is_float , is_initializer , is_string
9+ from brainpy import check
1010from brainpy .tools import to_size
1111from brainpy .types import ArrayType
1212from .base import Layer
@@ -36,8 +36,9 @@ class Reservoir(Layer):
3636 A float between 0 and 1.
3737 activation : str, callable, optional
3838 Reservoir activation function.
39+
3940 - If a str, should be a :py:mod:`brainpy.math.activations` function name.
40- - If a callable, should be an element-wise operator on tensor .
41+ - If a callable, should be an element-wise operator.
4142 activation_type : str
4243 - If "internal" (default), then leaky integration happens on states transformed
4344 by the activation function:
@@ -66,9 +67,12 @@ class Reservoir(Layer):
6667 neurons connected to other reservoir neurons, including themselves.
6768 Must be in [0, 1], by default 0.1
6869 comp_type: str
69- The connectivity type, can be "dense" or "sparse".
70+ The connectivity type, can be "dense", "sparse", or "jit".
71+
72+ - ``"dense"`` means the connectivity matrix is a dense matrix.
73+ - ``"sparse"`` means the connectivity matrix is a CSR sparse matrix.
7074 spectral_radius : float, optional
71- Spectral radius of recurrent weight matrix, by default None
75+ Spectral radius of recurrent weight matrix, by default None.
7276 noise_rec : float, optional
7377 Gain of noise applied to reservoir internal states, by default 0.0
7478 noise_in : float, optional
@@ -118,37 +122,38 @@ def __init__(
118122 self .num_unit = num_out
119123 assert num_out > 0 , f'Must be a positive integer, but we got { num_out } '
120124 self .leaky_rate = leaky_rate
121- is_float (leaky_rate , 'leaky_rate' , 0. , 1. )
122- self .activation = getattr (bm .activations , activation )
125+ check .is_float (leaky_rate , 'leaky_rate' , 0. , 1. )
126+ self .activation = getattr (bm .activations , activation ) if isinstance (activation , str ) else activation
127+ check .is_callable (self .activation , allow_none = False )
123128 self .activation_type = activation_type
124- is_string (activation_type , 'activation_type' , ['internal' , 'external' ])
129+ check . is_string (activation_type , 'activation_type' , ['internal' , 'external' ])
125130 self .rng = bm .random .default_rng (seed )
126- is_float (spectral_radius , 'spectral_radius' , allow_none = True )
131+ check . is_float (spectral_radius , 'spectral_radius' , allow_none = True )
127132 self .spectral_radius = spectral_radius
128133
129134 # initializations
130- is_initializer (Win_initializer , 'ff_initializer' , allow_none = False )
131- is_initializer (Wrec_initializer , 'rec_initializer' , allow_none = False )
132- is_initializer (b_initializer , 'bias_initializer' , allow_none = True )
135+ check . is_initializer (Win_initializer , 'ff_initializer' , allow_none = False )
136+ check . is_initializer (Wrec_initializer , 'rec_initializer' , allow_none = False )
137+ check . is_initializer (b_initializer , 'bias_initializer' , allow_none = True )
133138 self ._Win_initializer = Win_initializer
134139 self ._Wrec_initializer = Wrec_initializer
135140 self ._b_initializer = b_initializer
136141
137142 # connectivity
138- is_float (in_connectivity , 'ff_connectivity' , 0. , 1. )
139- is_float (rec_connectivity , 'rec_connectivity' , 0. , 1. )
143+ check . is_float (in_connectivity , 'ff_connectivity' , 0. , 1. )
144+ check . is_float (rec_connectivity , 'rec_connectivity' , 0. , 1. )
140145 self .ff_connectivity = in_connectivity
141146 self .rec_connectivity = rec_connectivity
142- is_string (comp_type , 'conn_type' , ['dense' , 'sparse' ])
147+ check . is_string (comp_type , 'conn_type' , ['dense' , 'sparse' , 'jit ' ])
143148 self .comp_type = comp_type
144149
145150 # noises
146- is_float (noise_in , 'noise_ff' )
147- is_float (noise_rec , 'noise_rec' )
151+ check . is_float (noise_in , 'noise_ff' )
152+ check . is_float (noise_rec , 'noise_rec' )
148153 self .noise_ff = noise_in
149154 self .noise_rec = noise_rec
150155 self .noise_type = noise_type
151- is_string (noise_type , 'noise_type' , ['normal' , 'uniform' ])
156+ check . is_string (noise_type , 'noise_type' , ['normal' , 'uniform' ])
152157
153158 # initialize feedforward weights
154159 weight_shape = (input_shape [- 1 ], self .num_unit )
@@ -170,7 +175,7 @@ def __init__(
170175 conn_mat = self .rng .random (recurrent_shape ) > self .rec_connectivity
171176 self .Wrec [conn_mat ] = 0.
172177 if self .spectral_radius is not None :
173- current_sr = max (abs (jnp .linalg .eig (self .Wrec )[0 ]))
178+ current_sr = max (abs (jnp .linalg .eig (bm . as_jax ( self .Wrec ) )[0 ]))
174179 self .Wrec *= self .spectral_radius / current_sr
175180 if self .comp_type == 'sparse' and self .rec_connectivity < 1. :
176181 self .rec_pres , self .rec_posts = jnp .where (jnp .logical_not (bm .as_jax (conn_mat )))
@@ -186,11 +191,13 @@ def __init__(
def reset_state(self, batch_size=None):
    """Reset the reservoir state to all-zeros.

    Parameters
    ----------
    batch_size : int, optional
      Batch dimension used when (re)allocating the state variable;
      ``None`` keeps the non-batched layout.
    """
    # Re-initialize the state container with a fresh zero tensor of the
    # layer's output shape (brainpy's `variable` handles batching).
    zeros = jnp.zeros
    self.state.value = variable(zeros, batch_size, self.output_shape)
188193
189- def update (self , sha , x ):
194+ def update (self , * args ):
190195 """Feedforward output."""
191196 # inputs
192- x = jnp .concatenate (x , axis = - 1 )
193- if self .noise_ff > 0 : x += self .noise_ff * self .rng .uniform (- 1 , 1 , x .shape )
197+ x = args [0 ] if len (args ) == 1 else args [1 ]
198+ x = bm .as_jax (x )
199+ if self .noise_ff > 0 :
200+ x += self .noise_ff * self .rng .uniform (- 1 , 1 , x .shape )
194201 if self .comp_type == 'sparse' and self .ff_connectivity < 1. :
195202 sparse = {'data' : self .Win ,
196203 'index' : (self .ff_pres , self .ff_posts ),
0 commit comments