22
33import collections
44import gc
5- from typing import Union , Dict , Callable , Sequence , Optional , Tuple , Any
5+ import warnings
6+ from typing import Union , Dict , Callable , Sequence , Optional , Tuple
67
78import jax
89import jax .numpy as jnp
910import numpy as np
1011
11- from brainpy import tools , check
12+ from brainpy import tools
1213from brainpy ._src import math as bm
13- from brainpy ._src .math .ndarray import Variable , VariableView
14- from brainpy ._src .math .object_transform .base import BrainPyObject , Collector
1514from brainpy ._src .connect import TwoEndConnector , MatConn , IJConn , One2One , All2All
1615from brainpy ._src .initialize import Initializer , parameter , variable , Uniform , noise as init_noise
1716from brainpy ._src .integrators import odeint , sdeint
18- from brainpy .algorithms import OnlineAlgorithm , OfflineAlgorithm
17+ from brainpy ._src .math .ndarray import Variable , VariableView
18+ from brainpy ._src .math .object_transform .base import BrainPyObject , Collector
1919from brainpy .errors import NoImplementationError , UnsupportedError
2020from brainpy .types import ArrayType , Shape
2121
2222__all__ = [
2323 # general class
2424 'DynamicalSystem' ,
25+ 'Module' ,
2526 'FuncAsDynSys' ,
26- 'DSPartial' ,
2727
2828 # containers
2929 'Container' , 'Network' , 'Sequential' , 'System' ,
4848SLICE_VARS = 'slice_vars'
4949
5050
51+ def not_pass_shargs (func : Callable ):
52+ """Label the update function as the one without passing shared arguments.
53+
54+ The original update function explicitly requires shared arguments at the first place::
55+
56+ class TheModel(DynamicalSystem):
57+ def update(self, s, x):
58+ # s is the shared arguments, like `t`, `dt`, etc.
59+ pass
60+
61+ So, each time we call the model we should provide shared arguments into the model::
62+
63+ TheModel()(shared, inputs)
64+
65+ When we label the update function as ``do_not_pass_sha_args``, this time there is no
66+ need to call the dynamical system with shared arguments::
67+
68+ class NewModel(DynamicalSystem):
69+ @no_shared
70+ def update(self, x):
71+ pass
72+
73+ NewModel()(inputs)
74+
75+ .. versionadded:: 2.3.5
76+
77+ Parameters
78+ ----------
79+ func: Callable
80+ The function in the :py:class:`~.DynamicalSystem`.
81+
82+ Returns
83+ -------
84+ func: Callable
85+ The wrapped function for the class.
86+ """
87+ func ._new_style = True
88+ return func
89+
90+
5191class DynamicalSystem (BrainPyObject ):
5292 """Base Dynamical System class.
5393
@@ -65,7 +105,6 @@ class DynamicalSystem(BrainPyObject):
65105 we recommend users to use :py:func:`~.for_loop`, :py:class:`~.LoopOverTime`,
66106 :py:class:`~.DSRunner`, or :py:class:`~.DSTrainer`.
67107
68-
69108 Parameters
70109 ----------
71110 name : optional, str
@@ -74,12 +113,6 @@ class DynamicalSystem(BrainPyObject):
74113 The model computation mode. It should be instance of :py:class:`~.Mode`.
75114 """
76115
77- online_fit_by : Optional [OnlineAlgorithm ]
78- '''Online fitting method.'''
79-
80- offline_fit_by : Optional [OfflineAlgorithm ]
81- '''Offline fitting method.'''
82-
83116 global_delay_data : Dict [str , Tuple [Union [bm .LengthDelay , None ], Variable ]] = dict ()
84117 '''Global delay data, which stores the delay variables and corresponding delay targets.
85118 This variable is useful when the same target variable is used in multiple mappings,
@@ -97,15 +130,11 @@ def __init__(
97130 f'but we got { type (mode )} : { mode } ' )
98131 self ._mode = mode
99132
100- super (DynamicalSystem , self ).__init__ (name = name )
101-
102133 # local delay variables
103134 self .local_delay_vars : Dict [str , bm .LengthDelay ] = Collector ()
104135
105- # fitting parameters
106- self .online_fit_by = None
107- self .offline_fit_by = None
108- self .fit_record = dict ()
136+ # super initialization
137+ super (DynamicalSystem , self ).__init__ (name = name )
109138
110139 @property
111140 def mode (self ) -> bm .Mode :
@@ -124,7 +153,21 @@ def __repr__(self):
124153
125154 def __call__ (self , * args , ** kwargs ):
126155 """The shortcut to call ``update`` methods."""
127- return self .update (* args , ** kwargs )
156+ if hasattr (self .update , '_new_style' ) and getattr (self .update , '_new_style' ):
157+ if len (args ) and isinstance (args [0 ], dict ):
158+ bm .share .save_shargs (** args [0 ])
159+ return self .update (* args [1 :], ** kwargs )
160+ else :
161+ return self .update (* args , ** kwargs )
162+ else :
163+ if len (args ) and isinstance (args [0 ], dict ):
164+ return self .update (* args , ** kwargs )
165+ else :
166+ # If first argument is not shared argument,
167+ # we should get the shared arguments from the global context.
168+ # However, users should set and update shared arguments
169+ # in the global context when using this mode.
170+ return self .update (bm .share .get_shargs (), * args , ** kwargs )
128171
129172 def register_delay (
130173 self ,
@@ -339,26 +382,13 @@ def __del__(self):
339382 del self .__dict__ [key ]
340383 gc .collect ()
341384
342- @tools .not_customized
343- def online_init (self ):
344- raise NoImplementationError ('Subclass must implement online_init() function when using OnlineTrainer.' )
345-
346- @tools .not_customized
347- def online_fit (self ,
348- target : ArrayType ,
349- fit_record : Dict [str , ArrayType ]):
350- raise NoImplementationError ('Subclass must implement online_fit() function when using OnlineTrainer.' )
351-
352- @tools .not_customized
353- def offline_fit (self ,
354- target : ArrayType ,
355- fit_record : Dict [str , ArrayType ]):
356- raise NoImplementationError ('Subclass must implement offline_fit() function when using OfflineTrainer.' )
357-
358385 def clear_input (self ):
359386 pass
360387
361388
389+ Module = DynamicalSystem
390+
391+
362392class FuncAsDynSys (DynamicalSystem ):
363393 """Transform a Python function as a :py:class:`~.DynamicalSystem`
364394
@@ -411,31 +441,6 @@ def __repr__(self):
411441 f'{ indent } num_of_vars={ len (self .implicit_vars )} )' )
412442
413443
414- class DSPartial (FuncAsDynSys ):
415- def __init__ (
416- self ,
417- target : Callable ,
418- * args ,
419- child_objs : Union [Callable , BrainPyObject , Sequence [BrainPyObject ], Dict [str , BrainPyObject ]] = None ,
420- dyn_vars : Union [Variable , Sequence [Variable ], Dict [str , Variable ]] = None ,
421- shared : Dict = None ,
422- ** keywords
423- ):
424- super ().__init__ (target = target , child_objs = child_objs , dyn_vars = dyn_vars )
425-
426- check .is_dict_data (shared , all_none = True )
427- self .target = check .is_callable (target , )
428- self .args = tuple (args )
429- self .keywords = keywords
430- self .shared = dict () if shared is None else shared
431-
432- def __call__ (self , s , * args , ** keywords ):
433- assert isinstance (s , dict )
434- s = tools .DotDict (s ).update (self .shared )
435- args = self .args + (s ,) + args
436- keywords = {** self .keywords , ** keywords }
437- return self .target (* args , ** keywords )
438-
439444
440445class Container (DynamicalSystem ):
441446 """Container object which is designed to add other instances of DynamicalSystem.
@@ -639,7 +644,7 @@ def __repr__(self):
639644 entries = '\n ' .join (f' [{ i } ] { tools .repr_object (x )} ' for i , x in enumerate (self ._modules ))
640645 return f'{ self .__class__ .__name__ } (\n { entries } \n )'
641646
642- def update (self , * args ) -> ArrayType :
647+ def update (self , s , x ) -> ArrayType :
643648 """Update function of a sequential model.
644649
645650 Parameters
@@ -654,7 +659,6 @@ def update(self, *args) -> ArrayType:
654659 y: ArrayType
655660 The output tensor.
656661 """
657- s , x = (dict (), args [0 ]) if len (args ) == 1 else (args [0 ], args [1 ])
658662 for m in self ._modules :
659663 if isinstance (m , DynamicalSystem ):
660664 x = m (s , x )
@@ -818,7 +822,7 @@ def get_batch_shape(self, batch_size=None):
818822 else :
819823 return (batch_size ,) + self .varshape
820824
821- def update (self , tdi , x = None ):
825+ def update (self , * args ):
822826 """The function to specify the updating rule.
823827
824828 Parameters
0 commit comments