Skip to content

Commit 2d9ab53

Browse files
committed
fix LoopOverTime bug
1 parent f21aff0 commit 2d9ab53

File tree

2 files changed

+11
-56
lines changed

2 files changed

+11
-56
lines changed

brainpy/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
__version__ = "2.3.5"
3+
__version__ = "2.3.6"
44

55

66
# fundamental supporting modules
@@ -75,8 +75,7 @@
7575
TwoEndConn as TwoEndConn,
7676
CondNeuGroup as CondNeuGroup,
7777
Channel as Channel)
78-
from brainpy._src.dyn.transform import (NoSharedArg as NoSharedArg, # transformations
79-
LoopOverTime as LoopOverTime,)
78+
from brainpy._src.dyn.transform import (LoopOverTime as LoopOverTime,)
8079
from brainpy._src.dyn.runners import (DSRunner as DSRunner) # runner
8180
from brainpy._src.dyn.context import share, Delay
8281

@@ -207,7 +206,6 @@
207206
dyn.__dict__['TwoEndConn'] = TwoEndConn
208207
dyn.__dict__['CondNeuGroup'] = CondNeuGroup
209208
dyn.__dict__['Channel'] = Channel
210-
dyn.__dict__['NoSharedArg'] = NoSharedArg
211209
dyn.__dict__['LoopOverTime'] = LoopOverTime
212210
dyn.__dict__['DSRunner'] = DSRunner
213211

brainpy/_src/dyn/transform.py

Lines changed: 9 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
__all__ = [
1616
'LoopOverTime',
17-
'NoSharedArg',
1817
]
1918

2019

@@ -207,12 +206,13 @@ def __call__(
207206
if isinstance(duration_or_xs, float):
208207
shared = tools.DotDict()
209208
if self.t0 is not None:
210-
shared['t'] = jnp.arange(self.t0.value, duration_or_xs, self.dt)
209+
shared['t'] = jnp.arange(0, duration_or_xs, self.dt) + self.t0.value
211210
if self.i0 is not None:
212-
shared['i'] = jnp.arange(self.i0.value, shared['t'].shape[0])
211+
shared['i'] = jnp.arange(0, shared['t'].shape[0]) + self.i0.value
213212
xs = None
214213
if self.no_state:
215214
raise ValueError('Under the `no_state=True` setting, input cannot be a duration.')
215+
length = shared['t'].shape
216216

217217
else:
218218
inp_err_msg = ('\n'
@@ -278,8 +278,8 @@ def __call__(
278278

279279
else:
280280
shared = tools.DotDict()
281-
shared['t'] = jnp.arange(self.t0.value, self.dt * length[0], self.dt)
282-
shared['i'] = jnp.arange(self.i0.value, length[0])
281+
shared['t'] = jnp.arange(0, self.dt * length[0], self.dt) + self.t0.value
282+
shared['i'] = jnp.arange(0, length[0]) + self.i0.value
283283

284284
assert not self.no_state
285285
results = bm.for_loop(functools.partial(self._run, self.shared_arg),
@@ -295,6 +295,10 @@ def __call__(
295295

296296
def reset_state(self, batch_size=None):
297297
self.target.reset_state(batch_size)
298+
if self.i0 is not None:
299+
self.i0.value = jnp.asarray(0)
300+
if self.t0 is not None:
301+
self.t0.value = jnp.asarray(0.)
298302

299303
def _run(self, static_sh, dyn_sh, x):
300304
share.save(**static_sh, **dyn_sh)
@@ -304,50 +308,3 @@ def _run(self, static_sh, dyn_sh, x):
304308
self.target.clear_input()
305309
return outs
306310

307-
308-
class NoSharedArg(DynSysToBPObj):
309-
"""Transform an instance of :py:class:`~.DynamicalSystem` into a callable
310-
:py:class:`~.BrainPyObject` :math:`y=f(x)`.
311-
312-
.. note::
313-
314-
This object transforms a :py:class:`~.DynamicalSystem` into a :py:class:`~.BrainPyObject`.
315-
316-
If some children nodes need shared arguments, like :py:class:`~.Dropout` or
317-
:py:class:`~.LIF` models, using ``NoSharedArg`` will cause errors.
318-
319-
Examples
320-
--------
321-
322-
>>> import brainpy as bp
323-
>>> import brainpy.math as bm
324-
>>> l = bp.Sequential(bp.layers.Dense(100, 10),
325-
>>> bm.relu,
326-
>>> bp.layers.Dense(10, 2))
327-
>>> l = bp.NoSharedArg(l)
328-
>>> l(bm.random.random(256, 100))
329-
330-
Parameters
331-
----------
332-
target: DynamicalSystem
333-
The target to transform.
334-
name: str
335-
The transformed object name.
336-
"""
337-
338-
def __init__(self, target: DynamicalSystem, name: str = None):
339-
super().__init__(target=target, name=name)
340-
if isinstance(target, Sequential) and target.no_shared_arg:
341-
raise ValueError(f'It is a {Sequential.__name__} object with `no_shared_arg=True`, '
342-
f'which has already able to be called with `f(x)`. ')
343-
344-
def __call__(self, *args, **kwargs):
345-
return self.target(tools.DotDict(), *args, **kwargs)
346-
347-
def reset(self, batch_size=None):
348-
"""Reset function which reset the whole variables in the model.
349-
"""
350-
self.target.reset(batch_size)
351-
352-
def reset_state(self, batch_size=None):
353-
self.target.reset_state(batch_size)

0 commit comments

Comments
 (0)