33import collections
44import gc
55import inspect
6+ import warnings
67from typing import Union , Dict , Callable , Sequence , Optional , Any
78
89import numpy as np
2829
2930SLICE_VARS = 'slice_vars'
3031
32+ _update_deprecate_msg = '''
33+ From brainpy>=2.4.3, update() function no longer needs to receive a global shared argument.
34+
35+ Instead of using:
36+
37+ def update(self, tdi, *args, **kwagrs):
38+ ...
39+
40+ Please use:
41+
42+ def update(self, *args, **kwagrs):
43+ t = bp.share['t']
44+ ...
45+ '''
46+
3147
3248def not_pass_shared (func : Callable ):
3349 """Label the update function as the one without passing shared arguments.
@@ -160,13 +176,38 @@ def clear_input(self):
160176 pass
161177
162178 def step_run (self , i , * args , ** kwargs ):
179+ """The step run function.
180+
181+ This function can be directly applied to run the dynamical system.
182+ Particularly, ``i`` denotes the running index.
183+
184+ Args:
185+ i: The current running index.
186+ *args: The arguments of ``update()`` function.
187+ **kwargs: The arguments of ``update()`` function.
188+
189+ Returns:
190+ out: The update function returns.
191+ """
163192 global share
164193 if share is None :
165194 from brainpy ._src .context import share
166195 share .save (i = i , t = i * bm .dt )
167196 return self .update (* args , ** kwargs )
168197
169- jit_step_run = bm .cls_jit (step_run , inline = True )
198+ @bm .cls_jit (inline = True )
199+ def jit_step_run (self , i , * args , ** kwargs ):
200+ """The jitted step function for running.
201+
202+ Args:
203+ i: The current running index.
204+ *args: The arguments of ``update()`` function.
205+ **kwargs: The arguments of ``update()`` function.
206+
207+ Returns:
208+ out: The update function returns.
209+ """
210+ return self .step_run (i , * args , ** kwargs )
170211
171212 @property
172213 def mode (self ) -> bm .Mode :
@@ -189,32 +230,35 @@ def _compatible_update(self, *args, **kwargs):
189230
190231 if len (update_args ) and update_args [0 ].name in ['tdi' , 'sh' , 'sha' ]:
191232 if len (args ) > 0 :
192- if isinstance (args [0 ], dict ):
233+ if isinstance (args [0 ], dict ) and all ([ bm . isscalar ( v ) for v in args [ 0 ]. values ()]) :
193234 # define:
194235 # update(tdi, *args, **kwargs)
195236 # call:
196237 # update(tdi, *args, **kwargs)
197238 ret = update_fun (* args , ** kwargs )
198- # TODO: deprecation
239+ warnings . warn ( _update_deprecate_msg , UserWarning )
199240 else :
200241 # define:
201242 # update(tdi, *args, **kwargs)
202243 # call:
203244 # update(*args, **kwargs)
204245 ret = update_fun (share .get_shargs (), * args , ** kwargs )
246+ warnings .warn (_update_deprecate_msg , UserWarning )
205247 else :
206248 if update_args [0 ].name in kwargs :
207249 # define:
208250 # update(tdi, *args, **kwargs)
209251 # call:
210252 # update(tdi=??, **kwargs)
211253 ret = update_fun (** kwargs )
254+ warnings .warn (_update_deprecate_msg , UserWarning )
212255 else :
213256 # define:
214257 # update(tdi, *args, **kwargs)
215258 # call:
216259 # update(**kwargs)
217260 ret = update_fun (share .get_shargs (), * args , ** kwargs )
261+ warnings .warn (_update_deprecate_msg , UserWarning )
218262 return ret
219263
220264 try :
@@ -230,6 +274,7 @@ def _compatible_update(self, *args, **kwargs):
230274 # update(*args, **kwargs)
231275 share .save (** args [0 ])
232276 ret = update_fun (* args [1 :], ** kwargs )
277+ warnings .warn (_update_deprecate_msg , UserWarning )
233278 return ret
234279 else :
235280 # user define ``update()`` function which receives the shared argument,
@@ -240,6 +285,7 @@ def _compatible_update(self, *args, **kwargs):
240285 # as
241286 # update(tdi, *args, **kwargs)
242287 ret = update_fun (share .get_shargs (), * args , ** kwargs )
288+ warnings .warn (_update_deprecate_msg , UserWarning )
243289 return ret
244290 else :
245291 return update_fun (* args , ** kwargs )
0 commit comments