11# -*- coding: utf-8 -*-
22
3+ import inspect
34import math
45import time
6+ import warnings
57from typing import Callable , Union , Dict , Sequence , Tuple
68
79import jax .numpy as jnp
1416from brainpy import optim , losses
1517from brainpy ._src .analysis import utils , base , constants
1618from brainpy ._src .dynsys import DynamicalSystem
19+ from brainpy ._src .context import share
1720from brainpy ._src .runners import check_and_format_inputs , _f_ops
18- from brainpy ._src .tools .dicts import DotDict
1921from brainpy .errors import AnalyzerError , UnsupportedError
2022from brainpy .types import ArrayType
23+ from brainpy ._src .deprecations import _input_deprecate_msg
24+
2125
2226__all__ = [
2327 'SlowPointFinder' ,
@@ -123,7 +127,7 @@ def __init__(
123127 f_loss_batch : Callable = None ,
124128 fun_inputs : Callable = None ,
125129 ):
126- super (SlowPointFinder , self ).__init__ ()
130+ super ().__init__ ()
127131
128132 # static arguments
129133 if not isinstance (args , tuple ):
@@ -514,7 +518,7 @@ def exclude_outliers(self, tolerance: float = 1e0):
514518 # Compute pairwise distances between all fixed points.
515519 distances = np .asarray (utils .euclidean_distance_jax (self .fixed_points , num_fps ))
516520
517- # Find second smallest element in each column of the pairwise distance matrix.
521+ # Find the second smallest element in each column of the pairwise distance matrix.
518522 # This corresponds to the closest neighbor for each fixed point.
519523 closest_neighbor = np .partition (distances , kth = 1 , axis = 0 )[1 ]
520524
@@ -636,11 +640,16 @@ def decompose_eigenvalues(matrices, sort_by='magnitude', do_compute_lefts=False)
636640 'L' : L })
637641 return decompositions
638642
639- def _step_func_input (self , shared ):
643+ def _step_func_input (self ):
640644 if self ._inputs is None :
641645 return
642646 elif callable (self ._inputs ):
643- self ._inputs (shared )
647+ try :
648+ ba = inspect .signature (self ._inputs ).bind (dict ())
649+ self ._inputs (share .get_shargs ())
650+ warnings .warn (_input_deprecate_msg , UserWarning )
651+ except TypeError :
652+ self ._inputs ()
644653 else :
645654 for ops , values in self ._inputs ['fixed' ].items ():
646655 for var , data in values :
@@ -650,7 +659,7 @@ def _step_func_input(self, shared):
650659 raise UnsupportedError
651660 for ops , values in self ._inputs ['functional' ].items ():
652661 for var , data in values :
653- _f_ops (ops , var , data (shared ))
662+ _f_ops (ops , var , data (share . get_shargs () ))
654663 for ops , values in self ._inputs ['iterated' ].items ():
655664 if len (values ) > 0 :
656665 raise UnsupportedError
@@ -732,9 +741,10 @@ def _generate_ds_cell_function(
732741 ):
733742 if dt is None : dt = bm .get_dt ()
734743 if t is None : t = 0.
735- shared = DotDict (t = t , dt = dt , i = 0 )
736744
737745 def f_cell (h : Dict ):
746+ share .save (t = t , i = 0 , dt = dt )
747+
738748 # update target variables
739749 for k , v in self .target_vars .items ():
740750 v .value = (bm .asarray (h [k ], dtype = v .dtype )
@@ -747,11 +757,10 @@ def f_cell(h: Dict):
747757
748758 # add inputs
749759 target .clear_input ()
750- self ._step_func_input (shared )
760+ self ._step_func_input ()
751761
752762 # call update functions
753- args = (shared ,) + self .args
754- target (* args )
763+ target (* self .args )
755764
756765 # get new states
757766 new_h = {k : (v .value if (v .batch_axis is None ) else jnp .squeeze (v .value , axis = v .batch_axis ))
0 commit comments