Skip to content

Commit 15ae3ae

Browse files
committed
updates
1 parent 520d828 commit 15ae3ae

File tree

8 files changed

+80
-48
lines changed

8 files changed

+80
-48
lines changed

brainpy/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
Network = DynSysGroup
6666
# delays
6767
from brainpy._src.delay import (
68-
VariDelay as VariDelay,
68+
VarDelay as VarDelay,
6969
)
7070

7171
# building blocks
@@ -129,12 +129,16 @@
129129
from brainpy._add_deprecations import deprecation_getattr2
130130

131131
__deprecations = {
132+
'Module': ('brainpy.Module', 'brainpy.DynamicalSystem', DynamicalSystem),
133+
'Channel': ('brainpy.Channel', 'brainpy.dyn.IonChannel', dyn.IonChannel),
134+
'NeuGroup': ('brainpy.NeuGroup', 'brainpy.dyn.NeuDyn', dyn.NeuDyn),
135+
'SynConn': ('brainpy.SynConn', 'brainpy.dyn.SynConn', dyn.SynConn),
132136
'Container': ('brainpy.Container', 'brainpy.DynSysGroup', DynSysGroup),
137+
133138
'optimizers': ('brainpy.optimizers', 'brainpy.optim', optim),
134139
'TensorCollector': ('brainpy.TensorCollector', 'brainpy.ArrayCollector', ArrayCollector),
135140
'SynSTP': ('brainpy.SynSTP', 'brainpy.synapses.SynSTP', synapses.SynSTP),
136141
'SynOut': ('brainpy.SynOut', 'brainpy.synapses.SynOut', synapses.SynOut),
137-
'SynConn': ('brainpy.SynConn', 'brainpy.dyn.SynConn', dyn.SynConn),
138142
'TwoEndConn': ('brainpy.TwoEndConn', 'brainpy.synapses.TwoEndConn', synapses.TwoEndConn),
139143
'CondNeuGroup': ('brainpy.CondNeuGroup', 'brainpy.syn.CondNeuGroup', dyn.CondNeuGroup),
140144
}

brainpy/_src/delay.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
__all__ = [
2323
'Delay',
24-
'VariDelay',
24+
'VarDelay',
2525
'DataDelay',
2626
'DelayAccess',
2727
]
@@ -432,7 +432,7 @@ def _check_target_sharding(sharding, ndim, mode: bm.Mode):
432432
return sharding
433433

434434

435-
class VariDelay(Delay):
435+
class VarDelay(Delay):
436436
"""Generate Delays for the given :py:class:`~.Variable` instance.
437437
438438
The data in this delay variable is arranged as::
@@ -690,7 +690,7 @@ def _init_data(self, length: int, batch_size: int = None):
690690
self.data[:] = self._init((length,) + self.target.shape, dtype=self.target.dtype)
691691

692692

693-
class DataDelay(VariDelay):
693+
class DataDelay(VarDelay):
694694
not_desc_params = ('time', 'entries')
695695

696696
def __init__(

brainpy/_src/deprecations.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,40 @@
88
]
99

1010

11+
_update_deprecate_msg = '''
12+
From brainpy>=2.4.3, update() function no longer needs to receive a global shared argument.
13+
14+
Instead of using:
15+
16+
def update(self, tdi, *args, **kwagrs):
17+
t = tdi['t']
18+
...
19+
20+
Please use:
21+
22+
def update(self, *args, **kwagrs):
23+
t = bp.share['t']
24+
...
25+
'''
26+
27+
28+
_input_deprecate_msg = '''
29+
From brainpy>=2.4.3, input() function no longer needs to receive a global shared argument.
30+
31+
Instead of using:
32+
33+
def input(tdi):
34+
...
35+
36+
Please use:
37+
38+
def input():
39+
t = bp.share['t']
40+
...
41+
'''
42+
43+
44+
1145
def _deprecate(msg):
1246
warnings.simplefilter('always', DeprecationWarning) # turn off filter
1347
warnings.warn(msg, category=DeprecationWarning, stacklevel=2)

brainpy/_src/dyn/projections/aligns.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import jax
44

55
from brainpy import math as bm, check
6-
from brainpy._src.delay import Delay, VariDelay, DataDelay, DelayAccess
6+
from brainpy._src.delay import Delay, VarDelay, DataDelay, DelayAccess
77
from brainpy._src.dynsys import DynamicalSystem, Projection, Dynamic, Sequential
88
from brainpy._src.mixin import JointType, ParamDescInit, ReturnInfo, AutoDelaySupp, BindCondData, AlignPost
99

@@ -54,7 +54,7 @@ def update(self):
5454

5555
def _init_delay(info: Union[bm.Variable, ReturnInfo]) -> Delay:
5656
if isinstance(info, bm.Variable):
57-
return VariDelay(info)
57+
return VarDelay(info)
5858
elif isinstance(info, ReturnInfo):
5959
if isinstance(info.batch_or_mode, int):
6060
shape = (info.batch_or_mode,) + tuple(info.size)
@@ -106,7 +106,7 @@ def __init__(self):
106106
super().__init__()
107107
self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
108108
V_initializer=bp.init.Normal(-55., 2.))
109-
self.delay = bp.VariableDelay(self.N.spike, entries={'I': None})
109+
self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
110110
self.syn1 = bp.dyn.Expon(size=3200, tau=5.)
111111
self.syn2 = bp.dyn.Expon(size=800, tau=10.)
112112
self.E = bp.dyn.VanillaProj(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
@@ -180,7 +180,7 @@ def __init__(self):
180180
super().__init__()
181181
self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
182182
V_initializer=bp.init.Normal(-55., 2.))
183-
self.delay = bp.VariableDelay(self.N.spike, entries={'I': None})
183+
self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
184184
self.E = bp.dyn.ProjAlignPostMg1(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
185185
syn=bp.dyn.Expon.desc(size=4000, tau=5.),
186186
out=bp.dyn.COBA.desc(E=0.),
@@ -374,7 +374,7 @@ def __init__(self):
374374
super().__init__()
375375
self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
376376
V_initializer=bp.init.Normal(-55., 2.))
377-
self.delay = bp.VariableDelay(self.N.spike, entries={'I': None})
377+
self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
378378
self.E = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
379379
syn=bp.dyn.Expon(size=4000, tau=5.),
380380
out=bp.dyn.COBA(E=0.),

brainpy/_src/dynsys.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from brainpy._src.mixin import AutoDelaySupp, Container, DelayRegister, global_delay_data
1414
from brainpy.errors import NoImplementationError, UnsupportedError
1515
from brainpy.types import ArrayType, Shape
16+
from brainpy._src.deprecations import _update_deprecate_msg
1617

1718
share = None
1819

@@ -29,21 +30,6 @@
2930

3031
SLICE_VARS = 'slice_vars'
3132

32-
_update_deprecate_msg = '''
33-
From brainpy>=2.4.3, update() function no longer needs to receive a global shared argument.
34-
35-
Instead of using:
36-
37-
def update(self, tdi, *args, **kwagrs):
38-
...
39-
40-
Please use:
41-
42-
def update(self, *args, **kwagrs):
43-
t = bp.share['t']
44-
...
45-
'''
46-
4733

4834
def not_pass_shared(func: Callable):
4935
"""Label the update function as the one without passing shared arguments.
@@ -305,14 +291,14 @@ def __repr__(self):
305291
def __call__(self, *args, **kwargs):
306292
"""The shortcut to call ``update`` methods."""
307293

308-
# update ``before_updates``
294+
# ``before_updates``
309295
for model in self.before_updates.values():
310296
model()
311297

312298
# update the model self
313299
ret = self.update(*args, **kwargs)
314300

315-
# update ``after_updates``
301+
# ``after_updates``
316302
for model in self.after_updates.values():
317303
model(ret)
318304
return ret

brainpy/_src/math/ndarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ def split(self, indices_or_sections, axis=0):
748748
sub-arrays : list of ndarrays
749749
A list of sub-arrays as views into `ary`.
750750
"""
751-
return [_return(a) for a in self.value.split(indices_or_sections, axis=axis)]
751+
return [_return(a) for a in jnp.split(self.value, indices_or_sections, axis=axis)]
752752

753753
def take(self, indices, axis=None, mode=None):
754754
"""Return an array formed from the elements of a at the given indices."""

brainpy/_src/math/object_transform/controls.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,7 @@ def _get_for_loop_transform(
722722
progress_bar: bool,
723723
remat: bool,
724724
reverse: bool,
725-
unroll: int
725+
unroll: int,
726726
):
727727
def fun2scan(carry, x):
728728
for k in dyn_vars.keys():
@@ -753,6 +753,7 @@ def for_loop(
753753
remat: bool = False,
754754
jit: Optional[bool] = None,
755755
progress_bar: bool = False,
756+
unroll_kwargs: Optional[Dict] = None,
756757

757758
# deprecated
758759
dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
@@ -845,6 +846,8 @@ def for_loop(
845846
.. deprecated:: 2.4.0
846847
No longer need to provide ``child_objs``. This function is capable of automatically
847848
collecting the children objects used in the target ``func``.
849+
unroll_kwargs: dict
850+
The keyword arguments without unrolling.
848851
849852
Returns
850853
-------
@@ -855,6 +858,9 @@ def for_loop(
855858
dynvar_deprecation(dyn_vars)
856859
node_deprecation(child_objs)
857860

861+
if unroll_kwargs is None:
862+
unroll_kwargs = dict()
863+
858864
if not isinstance(operands, (list, tuple)):
859865
operands = (operands,)
860866

@@ -885,7 +891,9 @@ def for_loop(
885891
dyn_vars = VariableStack()
886892

887893
# TODO: cache mechanism?
888-
transform = _get_for_loop_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll)
894+
transform = _get_for_loop_transform(body_fun, dyn_vars, bar,
895+
progress_bar, remat, reverse,
896+
unroll)
889897
if jit:
890898
dyn_vals, out_vals = transform(operands)
891899
else:

brainpy/dyn/channels.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,29 @@
11
from brainpy._src.dyn.channels.base import (
2-
IonChannel,
2+
IonChannel as IonChannel,
33
)
44

55
from brainpy._src.dyn.channels.calcium import (
6-
CalciumChannel,
7-
ICaN_IS2008,
8-
ICaT_HM1992,
9-
ICaT_HP1992,
10-
ICaHT_HM1992,
11-
ICaHT_Re1993,
12-
ICaL_IS2008,
6+
CalciumChannel as CalciumChannel,
7+
ICaN_IS2008 as ICaN_IS2008,
8+
ICaT_HM1992 as ICaT_HM1992,
9+
ICaT_HP1992 as ICaT_HP1992,
10+
ICaHT_HM1992 as ICaHT_HM1992,
11+
ICaHT_Re1993 as ICaHT_Re1993,
12+
ICaL_IS2008 as ICaL_IS2008,
1313
)
1414

1515

1616
from brainpy._src.dyn.channels.potassium import (
17-
PotassiumChannel,
18-
IKDR_Ba2002v2,
19-
IK_TM1991v2,
20-
IK_HH1952v2,
21-
IKA1_HM1992v2,
22-
IKA2_HM1992v2,
23-
IKK2A_HM1992v2,
24-
IKK2B_HM1992v2,
25-
IKNI_Ya1989v2,
26-
IK_Leak,
17+
PotassiumChannel as PotassiumChannel,
18+
IKDR_Ba2002v2 as IKDR_Ba2002v2,
19+
IK_TM1991v2 as IK_TM1991v2,
20+
IK_HH1952v2 as IK_HH1952v2,
21+
IKA1_HM1992v2 as IKA1_HM1992v2,
22+
IKA2_HM1992v2 as IKA2_HM1992v2,
23+
IKK2A_HM1992v2 as IKK2A_HM1992v2,
24+
IKK2B_HM1992v2 as IKK2B_HM1992v2,
25+
IKNI_Ya1989v2 as IKNI_Ya1989v2,
26+
IK_Leak as IK_Leak,
2727
)
2828
from brainpy._src.dyn.channels.potassium_compatible import (
2929
IKDR_Ba2002,

0 commit comments

Comments
 (0)