Skip to content

Commit 4cba72d

Browse files
committed
Support `brainpy.math.for_loop` with the keyword argument `unroll_kwargs`
1 parent 15ae3ae commit 4cba72d

File tree

9 files changed

+31
-191
lines changed

9 files changed

+31
-191
lines changed

brainpy/_src/math/object_transform/controls.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -723,11 +723,12 @@ def _get_for_loop_transform(
723723
remat: bool,
724724
reverse: bool,
725725
unroll: int,
726+
unroll_kwargs: tools.DotDict
726727
):
727728
def fun2scan(carry, x):
728729
for k in dyn_vars.keys():
729730
dyn_vars[k]._value = carry[k]
730-
results = body_fun(*x)
731+
results = body_fun(*x, **unroll_kwargs)
731732
if progress_bar:
732733
id_tap(lambda *arg: bar.update(), ())
733734
return dyn_vars.dict_data(), results
@@ -860,6 +861,7 @@ def for_loop(
860861

861862
if unroll_kwargs is None:
862863
unroll_kwargs = dict()
864+
unroll_kwargs = tools.DotDict(unroll_kwargs)
863865

864866
if not isinstance(operands, (list, tuple)):
865867
operands = (operands,)
@@ -871,19 +873,20 @@ def for_loop(
871873

872874
if jit is None: # jax disable jit
873875
jit = not jax.config.jax_disable_jit
874-
dyn_vars = get_stack_cache(body_fun)
876+
dyn_vars = get_stack_cache((body_fun, unroll_kwargs))
875877
if jit:
876878
if dyn_vars is None:
877879
# TODO: better cache mechanism?
878880
with new_transform('for_loop'):
879881
with VariableStack() as dyn_vars:
880882
transform = _get_for_loop_transform(body_fun, VariableStack(), bar,
881-
progress_bar, remat, reverse, unroll)
883+
progress_bar, remat, reverse, unroll,
884+
unroll_kwargs)
882885
if current_transform_number() > 1:
883886
rets = transform(operands)
884887
else:
885888
rets = jax.eval_shape(transform, operands)
886-
cache_stack(body_fun, dyn_vars) # cache
889+
cache_stack((body_fun, unroll_kwargs), dyn_vars) # cache
887890
if current_transform_number():
888891
return rets[1]
889892
del rets
@@ -893,7 +896,7 @@ def for_loop(
893896
# TODO: cache mechanism?
894897
transform = _get_for_loop_transform(body_fun, dyn_vars, bar,
895898
progress_bar, remat, reverse,
896-
unroll)
899+
unroll, unroll_kwargs)
897900
if jit:
898901
dyn_vals, out_vals = transform(operands)
899902
else:

brainpy/_src/runners.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ def predict(
466466
inputs = tree_map(lambda x: jnp.moveaxis(x, 0, 1), inputs)
467467

468468
# build monitor
469-
for key in self.mon.var_names:
469+
for key in self._monitors.keys():
470470
self.mon[key] = [] # reshape the monitor items
471471

472472
# init progress bar
@@ -492,7 +492,7 @@ def predict(
492492
# post-running for monitors
493493
if self._memory_efficient:
494494
self.mon['ts'] = indices * self.dt + self.t0
495-
for key in self.mon.var_names:
495+
for key in self._monitors.keys():
496496
self.mon[key] = np.asarray(self.mon[key])
497497
else:
498498
hists['ts'] = indices * self.dt + self.t0
@@ -658,6 +658,7 @@ def _fun_predict(self, indices, *inputs, shared_args=None):
658658
return outs, None
659659

660660
else:
661-
return bm.for_loop(functools.partial(self._step_func_predict, shared_args=shared_args),
661+
return bm.for_loop(self._step_func_predict,
662662
(indices, *inputs),
663-
jit=self.jit['predict'])
663+
jit=self.jit['predict'],
664+
unroll_kwargs={'shared_args': shared_args})

brainpy/_src/running/runner.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ def __init__(
118118

119119
# monitor for user access
120120
self.mon = DotDict()
121-
self.mon['var_names'] = tuple(self._monitors.keys())
122121

123122
# progress bar
124123
assert isinstance(progress_bar, bool), 'Must be a boolean variable.'

brainpy/_src/tools/dicts.py

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -42,64 +42,20 @@ class DotDict(dict):
4242
>>> f(d)
4343
TypeError: Argument 'a' of type <class 'str'> is not a valid JAX type.
4444
45-
At this moment, you can label this attribute `names` as not a key in the dictionary
46-
by using the syntax::
47-
48-
>>> d.add_attr_not_key('names')
49-
>>> f(d)
50-
{'a': DeviceArray(10, dtype=int32, weak_type=True),
51-
'b': DeviceArray(20, dtype=int32, weak_type=True),
52-
'c': DeviceArray(30, dtype=int32, weak_type=True)}
53-
5445
"""
5546

56-
'''Used to exclude variables that '''
57-
attrs_not_keys = ('attrs_not_keys', 'var_names')
58-
5947
def __init__(self, *args, **kwargs):
6048
super().__init__(*args, **kwargs)
6149
self.__dict__ = self
62-
self.var_names = ()
6350

6451
def copy(self) -> 'DotDict':
6552
return type(self)(super().copy())
6653

67-
def keys(self):
68-
"""Retrieve all keys in the dict, excluding ignored keys."""
69-
keys = []
70-
for k in super(DotDict, self).keys():
71-
if k not in self.attrs_not_keys:
72-
keys.append(k)
73-
return tuple(keys)
74-
75-
def values(self):
76-
"""Retrieve all values in the dict, excluding values of ignored keys."""
77-
values = []
78-
for k, v in super(DotDict, self).items():
79-
if k not in self.attrs_not_keys:
80-
values.append(v)
81-
return tuple(values)
82-
83-
def items(self):
84-
"""Retrieve all items in the dict, excluding ignored items."""
85-
items = []
86-
for k, v in super(DotDict, self).items():
87-
if k not in self.attrs_not_keys:
88-
items.append((k, v))
89-
return items
90-
9154
def to_numpy(self):
9255
"""Change all values to numpy arrays."""
9356
for key in tuple(self.keys()):
9457
self[key] = np.asarray(self[key])
9558

96-
def add_attr_not_key(self, *args):
97-
"""Add excluded attribute when retrieving dictionary keys. """
98-
for arg in args:
99-
if not isinstance(arg, str):
100-
raise TypeError('Only support string.')
101-
self.attrs_not_keys += args
102-
10359
def update(self, *args, **kwargs):
10460
super().update(*args, **kwargs)
10561
return self
@@ -179,7 +135,7 @@ def subset(self, var_type):
179135
180136
>>> import brainpy as bp
181137
>>>
182-
>>> some_collector = Collector()
138+
>>> some_collector = DotDict()
183139
>>>
184140
>>> # get all trainable variables
185141
>>> some_collector.subset(bp.math.TrainVar)

brainpy/_src/train/back_propagation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ def predict(
605605
self.target.reset_state(self._get_input_batch_size(xs=inputs))
606606
self.reset_state()
607607
# init monitor
608-
for key in self.mon.var_names:
608+
for key in self._monitors.keys():
609609
self.mon[key] = [] # reshape the monitor items
610610
# prediction
611611
if not isinstance(inputs, (tuple, list)):

brainpy/_src/train/online.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def fit(
177177
is_leaf=lambda y: isinstance(y, bm.Array))
178178

179179
# init monitor
180-
for key in self.mon.var_names:
180+
for key in self._monitors.keys():
181181
self.mon[key] = [] # reshape the monitor items
182182

183183
# init progress bar

examples/dynamics_simulation/COBA.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,6 @@ def run3():
168168
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
169169

170170

171-
172171
def run4():
173172
net = EICOBA_PostAlign(3200, 800)
174173
runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike})

tests/simulation/test_net_COBA.py

Lines changed: 0 additions & 118 deletions
This file was deleted.

tests/training/test_ESN.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,17 @@
66
class ESN(bp.DynamicalSystem):
77
def __init__(self, num_in, num_hidden, num_out):
88
super(ESN, self).__init__()
9-
self.r = bp.layers.Reservoir(num_in,
10-
num_hidden,
11-
Win_initializer=bp.init.Uniform(-0.1, 0.1),
12-
Wrec_initializer=bp.init.Normal(scale=0.1),
13-
in_connectivity=0.02,
14-
rec_connectivity=0.02,
15-
comp_type='dense')
16-
self.o = bp.layers.Dense(num_hidden,
17-
num_out,
18-
W_initializer=bp.init.Normal(),
19-
mode=bm.training_mode)
9+
self.r = bp.dnn.Reservoir(num_in,
10+
num_hidden,
11+
Win_initializer=bp.init.Uniform(-0.1, 0.1),
12+
Wrec_initializer=bp.init.Normal(scale=0.1),
13+
in_connectivity=0.02,
14+
rec_connectivity=0.02,
15+
comp_type='dense')
16+
self.o = bp.dnn.Dense(num_hidden,
17+
num_out,
18+
W_initializer=bp.init.Normal(),
19+
mode=bm.training_mode)
2020

2121
def update(self, x):
2222
return x >> self.r >> self.o
@@ -26,10 +26,10 @@ class NGRC(bp.DynamicalSystem):
2626
def __init__(self, num_in, num_out):
2727
super(NGRC, self).__init__()
2828

29-
self.r = bp.layers.NVAR(num_in, delay=2, order=2)
30-
self.o = bp.layers.Dense(self.r.num_out, num_out,
31-
W_initializer=bp.init.Normal(0.1),
32-
mode=bm.training_mode)
29+
self.r = bp.dnn.NVAR(num_in, delay=2, order=2)
30+
self.o = bp.dnn.Dense(self.r.num_out, num_out,
31+
W_initializer=bp.init.Normal(0.1),
32+
mode=bm.training_mode)
3333

3434
def update(self, x):
3535
return x >> self.r >> self.o

0 commit comments

Comments
 (0)