1313import brainpy .math as bm
1414from brainpy import optimizers as optim , losses
1515from brainpy .analysis import utils , base , constants
16- from brainpy .base import TensorCollector
16+ from brainpy .base import ArrayCollector
1717from brainpy .dyn .base import DynamicalSystem
1818from brainpy .dyn .runners import build_inputs , check_and_format_inputs
1919from brainpy .errors import AnalyzerError , UnsupportedError
2020from brainpy .tools .others .dicts import DotDict
21- from brainpy .types import Array
21+ from brainpy .types import ArrayType
2222
2323__all__ = [
2424 'SlowPointFinder' ,
@@ -136,11 +136,11 @@ def __init__(
136136
137137 # update function
138138 if target_vars is None :
139- self .target_vars = TensorCollector ()
139+ self .target_vars = ArrayCollector ()
140140 else :
141141 if not isinstance (target_vars , dict ):
142142 raise TypeError (f'"target_vars" must be a dict but we got { type (target_vars )} ' )
143- self .target_vars = TensorCollector (target_vars )
143+ self .target_vars = ArrayCollector (target_vars )
144144 excluded_vars = () if excluded_vars is None else excluded_vars
145145 if isinstance (excluded_vars , dict ):
146146 excluded_vars = tuple (excluded_vars .values ())
@@ -295,7 +295,7 @@ def selected_ids(self, val):
295295
296296 def find_fps_with_gd_method (
297297 self ,
298- candidates : Union [Array , Dict [str , Array ]],
298+ candidates : Union [ArrayType , Dict [str , ArrayType ]],
299299 tolerance : Union [float , Dict [str , float ]] = 1e-5 ,
300300 num_batch : int = 100 ,
301301 num_opt : int = 10000 ,
@@ -305,7 +305,7 @@ def find_fps_with_gd_method(
305305
306306 Parameters
307307 ----------
308- candidates : Array , dict
308+ candidates : ArrayType , dict
309309 The array with the shape of (batch size, state dim) of hidden states
310310 of RNN to start training for fixed points.
311311
@@ -335,7 +335,7 @@ def find_fps_with_gd_method(
335335 # set up optimization
336336 num_candidate = self ._check_candidates (candidates )
337337 if not (isinstance (candidates , (bm .ndarray , jnp .ndarray , np .ndarray )) or isinstance (candidates , dict )):
338- raise ValueError ('Candidates must be instance of Array or dict of Array .' )
338+ raise ValueError ('Candidates must be instance of ArrayType or dict of ArrayType .' )
339339 fixed_points = tree_map (lambda a : bm .TrainVar (a ), candidates , is_leaf = lambda x : isinstance (x , bm .Array ))
340340 f_eval_loss = self ._get_f_eval_loss ()
341341
@@ -401,14 +401,14 @@ def batch_train(start_i, n_batch):
401401
402402 def find_fps_with_opt_solver (
403403 self ,
404- candidates : Union [Array , Dict [str , Array ]],
404+ candidates : Union [ArrayType , Dict [str , ArrayType ]],
405405 opt_solver : str = 'BFGS'
406406 ):
407407 """Optimize fixed points with nonlinear optimization solvers.
408408
409409 Parameters
410410 ----------
411- candidates: Array , dict
411+ candidates: ArrayType , dict
412412 The candidate (initial) fixed points.
413413 opt_solver: str
414414 The solver of the optimization.
@@ -535,7 +535,7 @@ def exclude_outliers(self, tolerance: float = 1e0):
535535
536536 def compute_jacobians (
537537 self ,
538- points : Union [Array , Dict [str , Array ]],
538+ points : Union [ArrayType , Dict [str , ArrayType ]],
539539 stack_dict_var : bool = True ,
540540 plot : bool = False ,
541541 num_col : int = 4 ,
@@ -546,7 +546,7 @@ def compute_jacobians(
546546
547547 Parameters
548548 ----------
549- points: np.ndarray, bm.Array , jax.ndarray
549+ points: np.ndarray, bm.ArrayType , jax.ndarray
550550 The fixed points with the shape of (num_point, num_dim).
551551 stack_dict_var: bool
552552 Stack dictionary variables to calculate Jacobian matrix?
@@ -606,7 +606,7 @@ def decompose_eigenvalues(matrices, sort_by='magnitude', do_compute_lefts=False)
606606
607607 Parameters
608608 ----------
609- matrices: np.ndarray, bm.Array , jax.ndarray
609+ matrices: np.ndarray, bm.ArrayType , jax.ndarray
610610 A 3D array with the shape of (num_matrices, dim, dim).
611611 sort_by: str
612612 The method of sorting.
0 commit comments