
Commit 3422efc

fix #294: remove VariableView in `dyn_vars` of a runner (#295)

2 parents 431a8b7 + a99fe51

File tree: 7 files changed, +64 -6 lines changed

brainpy/dyn/runners.py

Lines changed: 1 addition & 0 deletions

@@ -585,6 +585,7 @@ def _step_func(t, i, x):
     if self.jit['predict']:
       dyn_vars = self.target.vars()
       dyn_vars.update(self.dyn_vars)
+      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
       run_func = lambda all_inputs: bm.for_loop(_step_func, dyn_vars.unique(), all_inputs)

     else:
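
For context: a `VariableView` aliases part of an existing `bm.Variable`, so when the runner collects `dyn_vars` for JIT compilation, the underlying `Variable` is already in the collection and the view is redundant there (the problem reported in #294). The same one-line filter recurs in every trainer below. A minimal sketch of the pattern, assuming the BrainPy 2.x Collector API used in this diff (`runner` stands for any `DSRunner`-like object; `run_func` is whatever step function is being compiled):

  import brainpy.math as bm

  dyn_vars = runner.target.vars()    # collect all Variables of the target model
  dyn_vars.update(runner.dyn_vars)   # plus any user-supplied variables
  # Drop VariableView entries: each one aliases a Variable the collector
  # already tracks, so registering both would duplicate the same state.
  dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
  f = bm.jit(run_func, dyn_vars=dyn_vars.unique())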

brainpy/dyn/tests/test_dyn_runner.py

Lines changed: 47 additions & 0 deletions

@@ -32,6 +32,53 @@ def update(self, tdi):
     runner = bp.dyn.DSRunner(ExampleDS(), dt=1., monitors=['i'], progress_bar=False)
     runner.run(100.)

+  def test_DSView(self):
+    class EINet(bp.dyn.Network):
+      def __init__(self, scale=1.0, method='exp_auto'):
+        super(EINet, self).__init__()
+
+        # network size
+        num_exc = int(800 * scale)
+        num_inh = int(200 * scale)
+
+        # neurons
+        pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
+        self.E = bp.neurons.LIF(num_exc, **pars, method=method)
+        self.I = bp.neurons.LIF(num_inh, **pars, method=method)
+        self.E.V[:] = bm.random.randn(num_exc) * 2 - 55.
+        self.I.V[:] = bm.random.randn(num_inh) * 2 - 55.
+
+        # synapses
+        we = 0.6 / scale  # excitatory synaptic weight (voltage)
+        wi = 6.7 / scale  # inhibitory synaptic weight
+        self.E2E = bp.synapses.Exponential(self.E, self.E[:100], bp.conn.FixedProb(0.02),
+                                           output=bp.synouts.COBA(E=0.), g_max=we,
+                                           tau=5., method=method)
+        self.E2I = bp.synapses.Exponential(self.E, self.I[:100], bp.conn.FixedProb(0.02),
+                                           output=bp.synouts.COBA(E=0.), g_max=we,
+                                           tau=5., method=method)
+        self.I2E = bp.synapses.Exponential(self.I, self.E[:100], bp.conn.FixedProb(0.02),
+                                           output=bp.synouts.COBA(E=-80.), g_max=wi,
+                                           tau=10., method=method)
+        self.I2I = bp.synapses.Exponential(self.I, self.I[:100], bp.conn.FixedProb(0.02),
+                                           output=bp.synouts.COBA(E=-80.), g_max=wi,
+                                           tau=10., method=method)
+
+    net = EINet(scale=1., method='exp_auto')
+    # with JIT
+    runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike},
+                         inputs=[(net.E.input, 20.), (net.I.input, 20.)]).run(1.)
+
+    # without JIT
+    runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike},
+                         inputs=[(net.E.input, 20.), (net.I.input, 20.)],
+                         jit=False).run(0.2)
+
+
 # class TestMonitor(TestCase):
 #   def test_1d_array(self):
 #     try1 = TryGroup(monitors=['a'])
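
The new `test_DSView` exercises the path that triggered #294: slicing a population, as in `self.E[:100]`, yields a view of the group, and variables collected through such views appear as `VariableView` objects that previously leaked into the runner's `dyn_vars`. A small illustrative check, assuming the slicing and `vars()` behavior shown in the test above (the names below are illustrative, not part of the diff):

  import brainpy as bp
  import brainpy.math as bm

  lif = bp.neurons.LIF(100)
  sub = lif[:10]                     # slicing a group returns a view over it
  # Variables gathered from the view alias lif's own state, which is why
  # the runner now filters them out before compiling.
  views = sub.vars().subset(bm.VariableView)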

brainpy/measure/correlation.py

Lines changed: 1 addition & 1 deletion

@@ -166,7 +166,7 @@ def voltage_fluctuation(potentials, numpy=True, method='loop'):
   References
   ----------
   .. [1] Golomb, D. and Rinzel J. (1993) Dynamics of globally coupled
-         inhibitory neurons with heterogeneity. Phys. Rev. reversal_potential 48:4810-4814.
+         inhibitory neurons with heterogeneity. Phys. Rev. E 48:4810-4814.
   .. [2] Golomb D. and Rinzel J. (1994) Clustering in globally coupled
          inhibitory neurons. Physica D 72:259-282.
   .. [3] David Golomb (2007) Neuronal synchrony measures. Scholarpedia, 2(1):1347.

brainpy/running/multiprocess.py

Lines changed: 5 additions & 5 deletions

@@ -21,7 +21,7 @@ def process_pool(func, all_params, num_process):
   ----------
   func : callable
     The function to run model.
-  all_params : a_list, tuple
+  all_params : list, tuple, dict
     The parameters of the function arguments.
     The parameters for each process can be a tuple, or a dictionary.
   num_process : int

@@ -47,7 +47,7 @@ def process_pool(func, all_params, num_process):
   return [r.get() for r in results]


-def process_pool_lock(func, all_net_params, nb_process):
+def process_pool_lock(func, all_params, nb_process):
   """Run multiple models in multi-processes with lock.

   Sometimes, you want to synchronize the processes. For example,

@@ -73,7 +73,7 @@ def some_func(..., lock, ...):
   ----------
   func : callable
     The function to run model.
-  all_net_params : a_list, tuple
+  all_params : list, tuple, dict
     The parameters of the function arguments.
   nb_process : int
     The number of the processes.

@@ -83,12 +83,12 @@ def some_func(..., lock, ...):
   results : list
     Process results.
   """
-  print('{} jobs total.'.format(len(all_net_params)))
+  print('{} jobs total.'.format(len(all_params)))
   pool = multiprocessing.Pool(processes=nb_process)
   m = multiprocessing.Manager()
   lock = m.Lock()
   results = []
-  for net_params in all_net_params:
+  for net_params in all_params:
     if isinstance(net_params, (list, tuple)):
       results.append(pool.apply_async(func, args=tuple(net_params) + (lock,)))
     elif isinstance(net_params, dict):
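
These docstring fixes also normalize the parameter name to `all_params`. For reference, a hypothetical usage sketch of `process_pool_lock` based on the code in this diff, which appends the shared lock as the last positional argument of each job (the `simulate` worker and its parameters are invented for illustration, and the import assumes the function is re-exported from `brainpy.running`):

  from brainpy.running import process_pool_lock

  def simulate(scale, duration, lock):
    # Use the pool's shared lock to serialize access to shared resources,
    # e.g. appending results to a common file.
    with lock:
      print('finished scale={}, duration={}'.format(scale, duration))

  if __name__ == '__main__':
    # One tuple of arguments per job; the lock is appended automatically.
    results = process_pool_lock(simulate, [(1.0, 100.0), (2.0, 100.0)], nb_process=2)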

brainpy/train/back_propagation.py

Lines changed: 8 additions & 0 deletions

@@ -300,6 +300,7 @@ def loss_fun(inputs, targets):
     if self.jit[c.LOSS_PHASE] and jit:
       dyn_vars = self.target.vars()
       dyn_vars.update(self.dyn_vars)
+      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
       self._f_loss_compiled[shared_args_str] = bm.jit(self._f_loss_compiled[shared_args_str],
                                                       dyn_vars=dyn_vars)
     return self._f_loss_compiled[shared_args_str]

@@ -311,6 +312,7 @@ def f_grad(self, shared_args=None) -> Callable:
     _f_loss_internal = self.f_loss(shared_args, jit=False)
     dyn_vars = self.target.vars()
     dyn_vars.update(self.dyn_vars)
+    dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
     tran_vars = dyn_vars.subset(bm.TrainVar)
     grad_f = bm.grad(_f_loss_internal,
                      dyn_vars=dyn_vars.unique(),

@@ -339,6 +341,7 @@ def train_func(inputs, targets):
       dyn_vars = self.target.vars()
       dyn_vars.update(self.dyn_vars)
       dyn_vars.update(self.optimizer.vars())
+      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
       self._f_train_compiled[shared_args_str] = bm.jit(train_func, dyn_vars=dyn_vars.unique())
     else:
       self._f_train_compiled[shared_args_str] = train_func

@@ -453,6 +456,7 @@ def loss_fun(inputs, targets):
     if self.jit[c.LOSS_PHASE] and jit:
       dyn_vars = self.target.vars()
       dyn_vars.update(self.dyn_vars)
+      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
       self._f_loss_compiled[shared_args_str] = bm.jit(self._f_loss_compiled[shared_args_str],
                                                       dyn_vars=dyn_vars)
     else:

@@ -480,6 +484,7 @@ def run_func(xs):
     if self.jit[c.PREDICT_PHASE] and jit:
       dyn_vars = self.target.vars()
       dyn_vars.update(self.dyn_vars)
+      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
       self._f_predict_compiled[shared_args_str] = bm.jit(run_func, dyn_vars=dyn_vars.unique())
     else:
       self._f_predict_compiled[shared_args_str] = run_func

@@ -505,6 +510,7 @@ def loss_fun(t, i, input_, target_):
     if self.jit[c.LOSS_PHASE] and jit:
       dyn_vars = self.target.vars()
       dyn_vars.update(self.dyn_vars)
+      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
       self._f_loss_compiled[shared_args_str] = bm.jit(self._f_loss_compiled[shared_args_str],
                                                       dyn_vars=dyn_vars)
     else:

@@ -529,6 +535,7 @@ def train_step(*x):
     if self.jit[c.FIT_PHASE]:
       dyn_vars = self.target.vars()
       dyn_vars.update(self.dyn_vars)
+      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
       run_func = lambda all_inputs: bm.for_loop(train_step, dyn_vars.unique(), all_inputs)

     else:

@@ -582,6 +589,7 @@ def run_func(t, i, x):
     if self.jit[c.FIT_PHASE] and jit:
       dyn_vars = self.target.vars()
       dyn_vars.update(self.dyn_vars)
+      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
       self._f_predict_compiled[shared_args_str] = bm.jit(run_func, dyn_vars=dyn_vars.unique())
     else:
       self._f_predict_compiled[shared_args_str] = run_func

brainpy/train/offline.py

Lines changed: 1 addition & 0 deletions

@@ -231,6 +231,7 @@ def train_func(monitor_data: Dict[str, Array], target_data: Dict[str, Array]):
     if self.jit['fit']:
       dyn_vars = self.target.vars()
       dyn_vars.update(self.dyn_vars)
+      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
       train_func = bm.jit(train_func, dyn_vars=dyn_vars.unique())
     return train_func

brainpy/train/online.py

Lines changed: 1 addition & 0 deletions

@@ -261,6 +261,7 @@ def _step_func(t, i, x, ys):
     if self.jit['fit']:
       dyn_vars = self.target.vars()
       dyn_vars.update(self.dyn_vars)
+      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
       return lambda all_inputs: bm.for_loop(_step_func, dyn_vars.unique(), all_inputs)

     else:
