Skip to content

Commit 329b6e7

Browse files
committed
add update() deprecation warning
1 parent d9a737b commit 329b6e7

File tree

3 files changed

+59
-12
lines changed

3 files changed

+59
-12
lines changed

brainpy/_src/analysis/highdim/slow_points.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from brainpy import optim, losses
1515
from brainpy._src.analysis import utils, base, constants
1616
from brainpy._src.dynsys import DynamicalSystem
17+
from brainpy._src.context import share
1718
from brainpy._src.runners import check_and_format_inputs, _f_ops
18-
from brainpy._src.tools.dicts import DotDict
1919
from brainpy.errors import AnalyzerError, UnsupportedError
2020
from brainpy.types import ArrayType
2121

@@ -123,7 +123,7 @@ def __init__(
123123
f_loss_batch: Callable = None,
124124
fun_inputs: Callable = None,
125125
):
126-
super(SlowPointFinder, self).__init__()
126+
super().__init__()
127127

128128
# static arguments
129129
if not isinstance(args, tuple):
@@ -636,11 +636,11 @@ def decompose_eigenvalues(matrices, sort_by='magnitude', do_compute_lefts=False)
636636
'L': L})
637637
return decompositions
638638

639-
def _step_func_input(self, shared):
639+
def _step_func_input(self):
640640
if self._inputs is None:
641641
return
642642
elif callable(self._inputs):
643-
self._inputs(shared)
643+
self._inputs(share.get_shargs())
644644
else:
645645
for ops, values in self._inputs['fixed'].items():
646646
for var, data in values:
@@ -650,7 +650,7 @@ def _step_func_input(self, shared):
650650
raise UnsupportedError
651651
for ops, values in self._inputs['functional'].items():
652652
for var, data in values:
653-
_f_ops(ops, var, data(shared))
653+
_f_ops(ops, var, data(share.get_shargs()))
654654
for ops, values in self._inputs['iterated'].items():
655655
if len(values) > 0:
656656
raise UnsupportedError
@@ -732,9 +732,10 @@ def _generate_ds_cell_function(
732732
):
733733
if dt is None: dt = bm.get_dt()
734734
if t is None: t = 0.
735-
shared = DotDict(t=t, dt=dt, i=0)
736735

737736
def f_cell(h: Dict):
737+
share.save(t=t, i=0, dt=dt)
738+
738739
# update target variables
739740
for k, v in self.target_vars.items():
740741
v.value = (bm.asarray(h[k], dtype=v.dtype)
@@ -747,11 +748,10 @@ def f_cell(h: Dict):
747748

748749
# add inputs
749750
target.clear_input()
750-
self._step_func_input(shared)
751+
self._step_func_input()
751752

752753
# call update functions
753-
args = (shared,) + self.args
754-
target(*args)
754+
target(*self.args)
755755

756756
# get new states
757757
new_h = {k: (v.value if (v.batch_axis is None) else jnp.squeeze(v.value, axis=v.batch_axis))

brainpy/_src/dynsys.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import collections
44
import gc
55
import inspect
6+
import warnings
67
from typing import Union, Dict, Callable, Sequence, Optional, Any
78

89
import numpy as np
@@ -28,6 +29,21 @@
2829

2930
SLICE_VARS = 'slice_vars'
3031

32+
_update_deprecate_msg = '''
33+
From brainpy>=2.4.3, update() function no longer needs to receive a global shared argument.
34+
35+
Instead of using:
36+
37+
def update(self, tdi, *args, **kwagrs):
38+
...
39+
40+
Please use:
41+
42+
def update(self, *args, **kwagrs):
43+
t = bp.share['t']
44+
...
45+
'''
46+
3147

3248
def not_pass_shared(func: Callable):
3349
"""Label the update function as the one without passing shared arguments.
@@ -160,13 +176,38 @@ def clear_input(self):
160176
pass
161177

162178
def step_run(self, i, *args, **kwargs):
179+
"""The step run function.
180+
181+
This function can be directly applied to run the dynamical system.
182+
Particularly, ``i`` denotes the running index.
183+
184+
Args:
185+
i: The current running index.
186+
*args: The arguments of ``update()`` function.
187+
**kwargs: The arguments of ``update()`` function.
188+
189+
Returns:
190+
out: The update function returns.
191+
"""
163192
global share
164193
if share is None:
165194
from brainpy._src.context import share
166195
share.save(i=i, t=i * bm.dt)
167196
return self.update(*args, **kwargs)
168197

169-
jit_step_run = bm.cls_jit(step_run, inline=True)
198+
@bm.cls_jit(inline=True)
199+
def jit_step_run(self, i, *args, **kwargs):
200+
"""The jitted step function for running.
201+
202+
Args:
203+
i: The current running index.
204+
*args: The arguments of ``update()`` function.
205+
**kwargs: The arguments of ``update()`` function.
206+
207+
Returns:
208+
out: The update function returns.
209+
"""
210+
return self.step_run(i, *args, **kwargs)
170211

171212
@property
172213
def mode(self) -> bm.Mode:
@@ -189,32 +230,35 @@ def _compatible_update(self, *args, **kwargs):
189230

190231
if len(update_args) and update_args[0].name in ['tdi', 'sh', 'sha']:
191232
if len(args) > 0:
192-
if isinstance(args[0], dict):
233+
if isinstance(args[0], dict) and all([bm.isscalar(v) for v in args[0].values()]):
193234
# define:
194235
# update(tdi, *args, **kwargs)
195236
# call:
196237
# update(tdi, *args, **kwargs)
197238
ret = update_fun(*args, **kwargs)
198-
# TODO: deprecation
239+
warnings.warn(_update_deprecate_msg, UserWarning)
199240
else:
200241
# define:
201242
# update(tdi, *args, **kwargs)
202243
# call:
203244
# update(*args, **kwargs)
204245
ret = update_fun(share.get_shargs(), *args, **kwargs)
246+
warnings.warn(_update_deprecate_msg, UserWarning)
205247
else:
206248
if update_args[0].name in kwargs:
207249
# define:
208250
# update(tdi, *args, **kwargs)
209251
# call:
210252
# update(tdi=??, **kwargs)
211253
ret = update_fun(**kwargs)
254+
warnings.warn(_update_deprecate_msg, UserWarning)
212255
else:
213256
# define:
214257
# update(tdi, *args, **kwargs)
215258
# call:
216259
# update(**kwargs)
217260
ret = update_fun(share.get_shargs(), *args, **kwargs)
261+
warnings.warn(_update_deprecate_msg, UserWarning)
218262
return ret
219263

220264
try:
@@ -230,6 +274,7 @@ def _compatible_update(self, *args, **kwargs):
230274
# update(*args, **kwargs)
231275
share.save(**args[0])
232276
ret = update_fun(*args[1:], **kwargs)
277+
warnings.warn(_update_deprecate_msg, UserWarning)
233278
return ret
234279
else:
235280
# user define ``update()`` function which receives the shared argument,
@@ -240,6 +285,7 @@ def _compatible_update(self, *args, **kwargs):
240285
# as
241286
# update(tdi, *args, **kwargs)
242287
ret = update_fun(share.get_shargs(), *args, **kwargs)
288+
warnings.warn(_update_deprecate_msg, UserWarning)
243289
return ret
244290
else:
245291
return update_fun(*args, **kwargs)

brainpy/dyn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@
77
from .projections import *
88
from .others import *
99
from .outs import *
10+
from .rates import *
1011
from .compat import NeuGroup

0 commit comments

Comments
 (0)