Skip to content

Commit 0a519c0

Browse files
authored
Merge pull request #343 from chaoming0625/master
Fix bug and more surrogate grad function supports
2 parents d8d23db + 1480220 commit 0a519c0

File tree

6 files changed

+224
-64
lines changed

6 files changed

+224
-64
lines changed

brainpy/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
__version__ = "2.3.5"
3+
__version__ = "2.3.6"
44

55

66
# fundamental supporting modules
@@ -75,8 +75,7 @@
7575
TwoEndConn as TwoEndConn,
7676
CondNeuGroup as CondNeuGroup,
7777
Channel as Channel)
78-
from brainpy._src.dyn.transform import (NoSharedArg as NoSharedArg, # transformations
79-
LoopOverTime as LoopOverTime,)
78+
from brainpy._src.dyn.transform import (LoopOverTime as LoopOverTime,)
8079
from brainpy._src.dyn.runners import (DSRunner as DSRunner) # runner
8180
from brainpy._src.dyn.context import share, Delay
8281

@@ -207,7 +206,6 @@
207206
dyn.__dict__['TwoEndConn'] = TwoEndConn
208207
dyn.__dict__['CondNeuGroup'] = CondNeuGroup
209208
dyn.__dict__['Channel'] = Channel
210-
dyn.__dict__['NoSharedArg'] = NoSharedArg
211209
dyn.__dict__['LoopOverTime'] = LoopOverTime
212210
dyn.__dict__['DSRunner'] = DSRunner
213211

brainpy/_src/dyn/transform.py

Lines changed: 9 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
__all__ = [
1616
'LoopOverTime',
17-
'NoSharedArg',
1817
]
1918

2019

@@ -207,12 +206,13 @@ def __call__(
207206
if isinstance(duration_or_xs, float):
208207
shared = tools.DotDict()
209208
if self.t0 is not None:
210-
shared['t'] = jnp.arange(self.t0.value, duration_or_xs, self.dt)
209+
shared['t'] = jnp.arange(0, duration_or_xs, self.dt) + self.t0.value
211210
if self.i0 is not None:
212-
shared['i'] = jnp.arange(self.i0.value, shared['t'].shape[0])
211+
shared['i'] = jnp.arange(0, shared['t'].shape[0]) + self.i0.value
213212
xs = None
214213
if self.no_state:
215214
raise ValueError('Under the `no_state=True` setting, input cannot be a duration.')
215+
length = shared['t'].shape
216216

217217
else:
218218
inp_err_msg = ('\n'
@@ -278,8 +278,8 @@ def __call__(
278278

279279
else:
280280
shared = tools.DotDict()
281-
shared['t'] = jnp.arange(self.t0.value, self.dt * length[0], self.dt)
282-
shared['i'] = jnp.arange(self.i0.value, length[0])
281+
shared['t'] = jnp.arange(0, self.dt * length[0], self.dt) + self.t0.value
282+
shared['i'] = jnp.arange(0, length[0]) + self.i0.value
283283

284284
assert not self.no_state
285285
results = bm.for_loop(functools.partial(self._run, self.shared_arg),
@@ -295,6 +295,10 @@ def __call__(
295295

296296
def reset_state(self, batch_size=None):
297297
self.target.reset_state(batch_size)
298+
if self.i0 is not None:
299+
self.i0.value = jnp.asarray(0)
300+
if self.t0 is not None:
301+
self.t0.value = jnp.asarray(0.)
298302

299303
def _run(self, static_sh, dyn_sh, x):
300304
share.save(**static_sh, **dyn_sh)
@@ -304,50 +308,3 @@ def _run(self, static_sh, dyn_sh, x):
304308
self.target.clear_input()
305309
return outs
306310

307-
308-
class NoSharedArg(DynSysToBPObj):
309-
"""Transform an instance of :py:class:`~.DynamicalSystem` into a callable
310-
:py:class:`~.BrainPyObject` :math:`y=f(x)`.
311-
312-
.. note::
313-
314-
This object transforms a :py:class:`~.DynamicalSystem` into a :py:class:`~.BrainPyObject`.
315-
316-
If some children nodes need shared arguments, like :py:class:`~.Dropout` or
317-
:py:class:`~.LIF` models, using ``NoSharedArg`` will cause errors.
318-
319-
Examples
320-
--------
321-
322-
>>> import brainpy as bp
323-
>>> import brainpy.math as bm
324-
>>> l = bp.Sequential(bp.layers.Dense(100, 10),
325-
>>> bm.relu,
326-
>>> bp.layers.Dense(10, 2))
327-
>>> l = bp.NoSharedArg(l)
328-
>>> l(bm.random.random(256, 100))
329-
330-
Parameters
331-
----------
332-
target: DynamicalSystem
333-
The target to transform.
334-
name: str
335-
The transformed object name.
336-
"""
337-
338-
def __init__(self, target: DynamicalSystem, name: str = None):
339-
super().__init__(target=target, name=name)
340-
if isinstance(target, Sequential) and target.no_shared_arg:
341-
raise ValueError(f'It is a {Sequential.__name__} object with `no_shared_arg=True`, '
342-
f'which has already able to be called with `f(x)`. ')
343-
344-
def __call__(self, *args, **kwargs):
345-
return self.target(tools.DotDict(), *args, **kwargs)
346-
347-
def reset(self, batch_size=None):
348-
"""Reset function which reset the whole variables in the model.
349-
"""
350-
self.target.reset(batch_size)
351-
352-
def reset_state(self, batch_size=None):
353-
self.target.reset_state(batch_size)

brainpy/_src/math/environment.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def clone(self):
324324
return self.__class__()
325325

326326

327-
def set_environment(
327+
def set(
328328
mode: modes.Mode = None,
329329
dt: float = None,
330330
x64: bool = None,
@@ -381,6 +381,9 @@ def set_environment(
381381
set_complex(complex_)
382382

383383

384+
set_environment = set
385+
386+
384387
class environment(_DecoratorContextManager):
385388
r"""Context-manager that sets a computing environment for brain dynamics computation.
386389

0 commit comments

Comments
 (0)