1818from brainpy .integrators import odeint , sdeint
1919from brainpy .modes import Mode , TrainingMode , BatchingMode , normal , training
2020from brainpy .tools .others import to_size , size2num
21- from brainpy .types import Tensor , Shape
21+ from brainpy .types import Array , Shape
2222
2323__all__ = [
2424 # general class
@@ -70,17 +70,17 @@ def __init__(
7070 name : str = None ,
7171 mode : Optional [Mode ] = None ,
7272 ):
73- super (DynamicalSystem , self ).__init__ (name = name )
74-
75- # local delay variables
76- self .local_delay_vars : Dict [str , bm .LengthDelay ] = Collector ()
77-
7873 # mode setting
7974 if mode is None : mode = normal
8075 if not isinstance (mode , Mode ):
8176 raise ValueError (f'Should be instance of { Mode .__name__ } , but we got { type (Mode )} : { Mode } ' )
8277 self ._mode = mode
8378
79+ super (DynamicalSystem , self ).__init__ (name = name )
80+
81+ # local delay variables
82+ self .local_delay_vars : Dict [str , bm .LengthDelay ] = Collector ()
83+
8484 # fitting parameters
8585 self .online_fit_by = None
8686 self .offline_fit_by = None
@@ -106,9 +106,9 @@ def __call__(self, *args, **kwargs):
106106 def register_delay (
107107 self ,
108108 identifier : str ,
109- delay_step : Optional [Union [int , Tensor , Callable , Initializer ]],
109+ delay_step : Optional [Union [int , Array , Callable , Initializer ]],
110110 delay_target : bm .Variable ,
111- initial_delay_data : Union [Initializer , Callable , Tensor , float , int , bool ] = None ,
111+ initial_delay_data : Union [Initializer , Callable , Array , float , int , bool ] = None ,
112112 ):
113113 """Register delay variable.
114114
@@ -317,14 +317,14 @@ def offline_init(self):
317317
318318 @tools .not_customized
319319 def online_fit (self ,
320- target : Tensor ,
321- fit_record : Dict [str , Tensor ]):
320+ target : Array ,
321+ fit_record : Dict [str , Array ]):
322322 raise NoImplementationError ('Subclass must implement online_fit() function when using OnlineTrainer.' )
323323
324324 @tools .not_customized
325325 def offline_fit (self ,
326- target : Tensor ,
327- fit_record : Dict [str , Tensor ]):
326+ target : Array ,
327+ fit_record : Dict [str , Array ]):
328328 raise NoImplementationError ('Subclass must implement offline_fit() function when using OfflineTrainer.' )
329329
330330 def clear_input (self ):
@@ -482,7 +482,7 @@ def f(x):
482482 entries = '\n ' .join (f' [{ i } ] { f (x )} ' for i , x in enumerate (self ))
483483 return f'{ self .__class__ .__name__ } (\n { entries } \n )'
484484
485- def update (self , sha : dict , x : Any ) -> Tensor :
485+ def update (self , sha : dict , x : Any ) -> Array :
486486 """Update function of a sequential model.
487487
488488 Parameters
@@ -494,7 +494,7 @@ def update(self, sha: dict, x: Any) -> Tensor:
494494
495495 Returns
496496 -------
497- y: Tensor
497+ y: Array
498498 The output tensor.
499499 """
500500 for node in self .implicit_nodes .values ():
@@ -686,7 +686,7 @@ def __init__(
686686 self ,
687687 pre : NeuGroup ,
688688 post : NeuGroup ,
689- conn : Union [TwoEndConnector , Tensor , Dict [str , Tensor ]] = None ,
689+ conn : Union [TwoEndConnector , Array , Dict [str , Array ]] = None ,
690690 name : str = None ,
691691 mode : Mode = normal ,
692692 ):
@@ -904,7 +904,7 @@ def __init__(
904904 self ,
905905 pre : NeuGroup ,
906906 post : NeuGroup ,
907- conn : Union [TwoEndConnector , Tensor , Dict [str , Tensor ]] = None ,
907+ conn : Union [TwoEndConnector , Array , Dict [str , Array ]] = None ,
908908 output : SynOut = NullSynOut (),
909909 stp : SynSTP = NullSynSTP (),
910910 ltp : SynLTP = NullSynLTP (),
@@ -946,10 +946,10 @@ def __init__(
946946
947947 def init_weights (
948948 self ,
949- weight : Union [float , Tensor , Initializer , Callable ],
949+ weight : Union [float , Array , Initializer , Callable ],
950950 comp_method : str ,
951951 sparse_data : str = 'csr'
952- ) -> Union [float , Tensor ]:
952+ ) -> Union [float , Array ]:
953953 if comp_method not in ['sparse' , 'dense' ]:
954954 raise ValueError (f'"comp_method" must be in "sparse" and "dense", but we got { comp_method } ' )
955955 if sparse_data not in ['csr' , 'ij' ]:
@@ -1061,11 +1061,11 @@ def __init__(
10611061 self ,
10621062 size : Shape ,
10631063 keep_size : bool = False ,
1064- C : Union [float , Tensor , Initializer , Callable ] = 1. ,
1065- A : Union [float , Tensor , Initializer , Callable ] = 1e-3 ,
1066- V_th : Union [float , Tensor , Initializer , Callable ] = 0. ,
1067- V_initializer : Union [Initializer , Callable , Tensor ] = Uniform (- 70 , - 60. ),
1068- noise : Union [float , Tensor , Initializer , Callable ] = None ,
1064+ C : Union [float , Array , Initializer , Callable ] = 1. ,
1065+ A : Union [float , Array , Initializer , Callable ] = 1e-3 ,
1066+ V_th : Union [float , Array , Initializer , Callable ] = 0. ,
1067+ V_initializer : Union [Initializer , Callable , Array ] = Uniform (- 70 , - 60. ),
1068+ noise : Union [float , Array , Initializer , Callable ] = None ,
10691069 method : str = 'exp_auto' ,
10701070 name : str = None ,
10711071 mode : Mode = normal ,
0 commit comments