44import gc
55from typing import Union , Dict , Callable , Sequence , Optional , Tuple , Any
66
7+ import jax
78import jax .numpy as jnp
89import numpy as np
910
1819from brainpy .errors import NoImplementationError , UnsupportedError
1920from brainpy .types import ArrayType , Shape
2021
21-
22-
2322__all__ = [
2423 # general class
2524 'DynamicalSystem' ,
@@ -170,14 +169,14 @@ def register_delay(
170169 raise ValueError (f'Unknown "delay_steps" type { type (delay_step )} , only support '
171170 f'integer, array of integers, callable function, brainpy.init.Initializer.' )
172171 if delay_type == 'heter' :
173- if delay_step .dtype not in [jnp .int32 , jnp .int64 ]:
172+ if delay_step .dtype not in [bm .int32 , bm .int64 ]:
174173 raise ValueError ('Only support delay steps of int32, int64. If your '
175174 'provide delay time length, please divide the "dt" '
176175 'then provide us the number of delay steps.' )
177176 if delay_target .shape [0 ] != delay_step .shape [0 ]:
178177 raise ValueError (f'Shape is mismatched: { delay_target .shape [0 ]} != { delay_step .shape [0 ]} ' )
179178 if delay_type != 'none' :
180- max_delay_step = int (jnp .max (delay_step ))
179+ max_delay_step = int (bm .max (delay_step ))
181180
182181 # delay target
183182 if delay_type != 'none' :
@@ -207,8 +206,8 @@ def register_delay(
207206 def get_delay_data (
208207 self ,
209208 identifier : str ,
210- delay_step : Optional [Union [int , bm .Array , jnp . DeviceArray ]],
211- * indices : Union [int , slice , bm .Array , jnp . DeviceArray ],
209+ delay_step : Optional [Union [int , bm .Array , jax . Array ]],
210+ * indices : Union [int , slice , bm .Array , jax . Array ],
212211 ):
213212 """Get delay data according to the provided delay steps.
214213
@@ -230,19 +229,19 @@ def get_delay_data(
230229 return self .global_delay_data [identifier ][1 ].value
231230
232231 if identifier in self .global_delay_data :
233- if jnp .ndim (delay_step ) == 0 :
232+ if bm .ndim (delay_step ) == 0 :
234233 return self .global_delay_data [identifier ][0 ](delay_step , * indices )
235234 else :
236235 if len (indices ) == 0 :
237- indices = (jnp .arange (delay_step .size ),)
236+ indices = (bm .arange (delay_step .size ),)
238237 return self .global_delay_data [identifier ][0 ](delay_step , * indices )
239238
240239 elif identifier in self .local_delay_vars :
241- if jnp .ndim (delay_step ) == 0 :
240+ if bm .ndim (delay_step ) == 0 :
242241 return self .local_delay_vars [identifier ](delay_step )
243242 else :
244243 if len (indices ) == 0 :
245- indices = (jnp .arange (delay_step .size ),)
244+ indices = (bm .arange (delay_step .size ),)
246245 return self .local_delay_vars [identifier ](delay_step , * indices )
247246
248247 else :
@@ -878,7 +877,7 @@ def __init__(
878877 # ------------
879878 if isinstance (conn , TwoEndConnector ):
880879 self .conn = conn (pre .size , post .size )
881- elif isinstance (conn , (bm .ndarray , np .ndarray , jnp . ndarray )):
880+ elif isinstance (conn , (bm .ndarray , np .ndarray , jax . Array )):
882881 if (pre .num , post .num ) != conn .shape :
883882 raise ValueError (f'"conn" is provided as a matrix, and it is expected '
884883 f'to be an array with shape of (pre.num, post.num) = '
@@ -1157,11 +1156,11 @@ def _init_weights(
11571156 return weight , conn_mask
11581157
11591158 def _syn2post_with_all2all (self , syn_value , syn_weight ):
1160- if jnp .ndim (syn_weight ) == 0 :
1159+ if bm .ndim (syn_weight ) == 0 :
11611160 if isinstance (self .mode , bm .BatchingMode ):
1162- post_vs = jnp .sum (syn_value , keepdims = True , axis = tuple (range (syn_value .ndim ))[1 :])
1161+ post_vs = bm .sum (syn_value , keepdims = True , axis = tuple (range (syn_value .ndim ))[1 :])
11631162 else :
1164- post_vs = jnp .sum (syn_value )
1163+ post_vs = bm .sum (syn_value )
11651164 if not self .conn .include_self :
11661165 post_vs = post_vs - syn_value
11671166 post_vs = syn_weight * post_vs
@@ -1173,7 +1172,7 @@ def _syn2post_with_one2one(self, syn_value, syn_weight):
11731172 return syn_value * syn_weight
11741173
11751174 def _syn2post_with_dense (self , syn_value , syn_weight , conn_mat ):
1176- if jnp .ndim (syn_weight ) == 0 :
1175+ if bm .ndim (syn_weight ) == 0 :
11771176 post_vs = (syn_weight * syn_value ) @ conn_mat
11781177 else :
11791178 post_vs = syn_value @ (syn_weight * conn_mat )
@@ -1253,8 +1252,8 @@ def __init__(
12531252
12541253 # variables
12551254 self .V = variable (V_initializer , self .mode , self .varshape )
1256- self .input = variable (jnp .zeros , self .mode , self .varshape )
1257- self .spike = variable (lambda s : jnp .zeros (s , dtype = bool ), self .mode , self .varshape )
1255+ self .input = variable (bm .zeros , self .mode , self .varshape )
1256+ self .spike = variable (lambda s : bm .zeros (s , dtype = bool ), self .mode , self .varshape )
12581257
12591258 # function
12601259 if self .noise is None :
@@ -1271,8 +1270,8 @@ def derivative(self, V, t):
12711270
12721271 def reset_state (self , batch_size = None ):
12731272 self .V .value = variable (self ._V_initializer , batch_size , self .varshape )
1274- self .spike .value = variable (lambda s : jnp .zeros (s , dtype = bool ), batch_size , self .varshape )
1275- self .input .value = variable (jnp .zeros , batch_size , self .varshape )
1273+ self .spike .value = variable (lambda s : bm .zeros (s , dtype = bool ), batch_size , self .varshape )
1274+ self .input .value = variable (bm .zeros , batch_size , self .varshape )
12761275 for channel in self .nodes (level = 1 , include_self = False ).subset (Channel ).unique ().values ():
12771276 channel .reset_state (self .V .value , batch_size = batch_size )
12781277
@@ -1286,7 +1285,7 @@ def update(self, tdi, *args, **kwargs):
12861285 # update variables
12871286 for node in channels .values ():
12881287 node .update (tdi , self .V .value )
1289- self .spike .value = jnp .logical_and (V >= self .V_th , self .V < self .V_th )
1288+ self .spike .value = bm .logical_and (V >= self .V_th , self .V < self .V_th )
12901289 self .V .value = V
12911290
12921291 def register_implicit_nodes (self , * channels , ** named_channels ):
@@ -1295,7 +1294,7 @@ def register_implicit_nodes(self, *channels, **named_channels):
12951294
12961295 def clear_input (self ):
12971296 """Useful for monitoring inputs. """
1298- self .input .value = jnp .zeros_like (self .input .value )
1297+ self .input .value = bm .zeros_like (self .input .value )
12991298
13001299
13011300class Channel (DynamicalSystem ):
0 commit comments