Skip to content

Commit 96a0652

Browse files
committed
fix bugs in delay vars
1 parent b25d717 commit 96a0652

File tree

7 files changed

+26
-25
lines changed

7 files changed

+26
-25
lines changed

brainpy/_src/dyn/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def __init__(
134134
self.local_delay_vars: Dict[str, bm.LengthDelay] = Collector()
135135

136136
# super initialization
137-
super(DynamicalSystem, self).__init__(name=name)
137+
BrainPyObject.__init__(self, name=name)
138138

139139
@property
140140
def mode(self) -> bm.Mode:
@@ -155,7 +155,8 @@ def __call__(self, *args, **kwargs):
155155
"""The shortcut to call ``update`` methods."""
156156
if hasattr(self.update, '_new_style') and getattr(self.update, '_new_style'):
157157
if len(args) and isinstance(args[0], dict):
158-
bm.share.save(**args[0])
158+
for k, v in args[0].items():
159+
bm.share.save(k, v)
159160
return self.update(*args[1:], **kwargs)
160161
else:
161162
return self.update(*args, **kwargs)

brainpy/_src/dyn/runners.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,8 @@ def _step_func_predict(self, shared_args, t, i, x):
615615
# input step
616616
shared = tools.DotDict(t=t, i=i, dt=self.dt)
617617
shared.update(shared_args)
618-
bm.share.save(**shared)
618+
for k, v in shared.items():
619+
bm.share.save(k, v)
619620
self.target.clear_input()
620621
self._step_func_input(shared)
621622

@@ -630,7 +631,6 @@ def _step_func_predict(self, shared_args, t, i, x):
630631
# finally
631632
if self.progress_bar:
632633
id_tap(lambda *arg: self._pbar.update(), ())
633-
bm.share.remove_shargs()
634634
return out, mon
635635

636636
def _get_f_predict(self, shared_args: Dict = None):

brainpy/_src/experimental/delay.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,18 @@ def __init__(
2121
length: int = 0,
2222
before_t0: Union[float, int, bool, bm.Array, jax.Array, Callable] = None,
2323
entries: Optional[Dict] = None,
24-
method: str = None,
24+
method: str = ROTATE_UPDATE,
2525
mode: bm.Mode = None,
2626
name: str = None,
2727
):
28+
DynamicalSystem.__init__(self, mode=mode)
2829
if method is None:
2930
if self.mode.is_a(bm.NonBatchingMode):
3031
method = ROTATE_UPDATE
3132
elif self.mode.is_parent_of(bm.TrainingMode):
3233
method = CONCAT_UPDATE
3334
else:
3435
method = ROTATE_UPDATE
35-
DynamicalSystem.__init__(self, mode=mode)
3636
DelayVariable.__init__(self,
3737
target=target,
3838
length=length,

brainpy/_src/math/context.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -64,24 +64,24 @@ def save(self, identifier: str, data: Any) -> None:
6464
"""Save shared arguments in the global context."""
6565
assert isinstance(identifier, str)
6666

67-
if identifier in self._identifiers:
68-
raise ValueError(f'{identifier} has been used. Please assign another name.')
6967
if isinstance(data, DelayVariable):
68+
if identifier in self._identifiers:
69+
raise ValueError(f'{identifier} has been used. Please assign another name.')
7070
self._delays[identifier] = data
71-
elif isinstance(data, DelayEntry):
72-
if isinstance(data.target, DelayVariable):
73-
delay_key = f'delay{id(data)}'
74-
self.save(delay_key, data.target)
75-
delay = data.target
76-
elif isinstance(data.target, str):
77-
if data.target not in self._delays:
78-
raise ValueError(f'Delay target {data.target} has not been registered.')
79-
delay = self._delays[data.target]
80-
delay_key = data.target
81-
else:
82-
raise ValueError(f'Unknown delay target. {type(data.target)}')
83-
delay.register_entry(identifier, delay_time=data.time, delay_step=data.step)
84-
self._delay_entries[identifier] = delay_key
71+
# elif isinstance(data, DelayEntry):
72+
# if isinstance(data.target, DelayVariable):
73+
# delay_key = f'delay{id(data)}'
74+
# self.save(delay_key, data.target)
75+
# delay = data.target
76+
# elif isinstance(data.target, str):
77+
# if data.target not in self._delays:
78+
# raise ValueError(f'Delay target {data.target} has not been registered.')
79+
# delay = self._delays[data.target]
80+
# delay_key = data.target
81+
# else:
82+
# raise ValueError(f'Unknown delay target. {type(data.target)}')
83+
# delay.register_entry(identifier, delay_time=data.time, delay_step=data.step)
84+
# self._delay_entries[identifier] = delay_key
8585
else:
8686
self._arguments[identifier] = data
8787
self._identifiers.add(identifier)

brainpy/_src/math/delayvars.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ def __init__(
519519
name: str = None,
520520
method: str = ROTATE_UPDATE,
521521
):
522-
super().__init__(name=name)
522+
BrainPyObject.__init__(self, name=name)
523523
assert method in [ROTATE_UPDATE, CONCAT_UPDATE]
524524
self.method = method
525525

brainpy/_src/train/back_propagation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,8 @@ def _step_func_fit(self, shared_args, inputs, targets):
566566

567567
def _step_func_predict(self, shared, x=None):
568568
assert self.data_first_axis == 'B', f'There is no time dimension when using the trainer {self.__class__.__name__}.'
569-
bm.share.save(**shared)
569+
for k, v in shared.items():
570+
bm.share.save(k, v)
570571

571572
# input step
572573
self.target.clear_input()

brainpy/math/activations.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,4 @@
2424
swish as swish,
2525
selu as selu,
2626
identity as identity,
27-
tanh as tanh,
2827
)

0 commit comments

Comments
 (0)