Skip to content

Commit b97b15b

Browse files
committed
use for_loop() instead of make_loop()
1 parent b9231e5 commit b97b15b

File tree

3 files changed

+20
-27
lines changed

3 files changed

+20
-27
lines changed

brainpy/dyn/runners.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -566,8 +566,7 @@ def f_predict(self, shared_args: Dict = None):
566566

567567
monitor_func = self.build_monitors(self._mon_info[0], self._mon_info[1], shared_args)
568568

569-
def _step_func(inputs):
570-
t, i, x = inputs
569+
def _step_func(t, i, x):
571570
self.target.clear_input()
572571
# input step
573572
shared = DotDict(t=t, i=i, dt=self.dt)
@@ -586,8 +585,7 @@ def _step_func(inputs):
586585
if self.jit['predict']:
587586
dyn_vars = self.target.vars()
588587
dyn_vars.update(self.dyn_vars)
589-
f = bm.make_loop(_step_func, dyn_vars=dyn_vars.unique(), has_return=True)
590-
run_func = lambda all_inputs: f(all_inputs)[1]
588+
run_func = lambda all_inputs: bm.for_loop(_step_func, dyn_vars.unique(), all_inputs)
591589

592590
else:
593591
def run_func(xs):
@@ -601,7 +599,7 @@ def run_func(xs):
601599
x = tree_map(lambda x: x[i], xs, is_leaf=lambda x: isinstance(x, bm.JaxArray))
602600

603601
# step at the i
604-
output, mon = _step_func((times[i], indices[i], x))
602+
output, mon = _step_func(times[i], indices[i], x)
605603

606604
# append output and monitor
607605
outputs.append(output)

brainpy/integrators/runner.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -217,16 +217,12 @@ def __init__(
217217

218218
# build the update step
219219
if self.jit['predict']:
220-
_loop_func = bm.make_loop(
221-
self._step,
222-
dyn_vars=self.dyn_vars,
223-
out_vars={k: self.variables[k] for k in self.monitors.keys()},
224-
has_return=True
225-
)
220+
def _loop_func(times):
221+
return bm.for_loop(self._step, self.dyn_vars, times)
226222
else:
227223
def _loop_func(times):
228-
out_vars = {k: [] for k in self.monitors.keys()}
229224
returns = {k: [] for k in self.fun_monitors.keys()}
225+
returns.update({k: [] for k in self.monitors.keys()})
230226
for i in range(len(times)):
231227
_t = times[i]
232228
_dt = self.dt
@@ -237,9 +233,9 @@ def _loop_func(times):
237233
self._step(_t)
238234
# variable monitors
239235
for k in self.monitors.keys():
240-
out_vars[k].append(bm.as_device_array(self.variables[k]))
241-
out_vars = {k: bm.asarray(out_vars[k]) for k in self.monitors.keys()}
242-
return out_vars, returns
236+
returns[k].append(bm.as_device_array(self.variables[k]))
237+
returns = {k: bm.asarray(returns[k]) for k in returns.keys()}
238+
return returns
243239
self.step_func = _loop_func
244240

245241
def _step(self, t):
@@ -252,11 +248,6 @@ def _step(self, t):
252248
kwargs.update({k: v[self.idx.value] for k, v in self._dyn_args.items()})
253249
self.idx += 1
254250

255-
# return of function monitors
256-
returns = dict()
257-
for key, func in self.fun_monitors.items():
258-
returns[key] = func(t, self.dt)
259-
260251
# call integrator function
261252
update_values = self.target(**kwargs)
262253
if len(self.target.variables) == 1:
@@ -268,6 +259,13 @@ def _step(self, t):
268259
# progress bar
269260
if self.progress_bar:
270261
id_tap(lambda *args: self._pbar.update(), ())
262+
263+
# return of function monitors
264+
returns = dict()
265+
for key, func in self.fun_monitors.items():
266+
returns[key] = func(t, self.dt)
267+
for k in self.monitors.keys():
268+
returns[k] = self.variables[k].value
271269
return returns
272270

273271
def run(self, duration, start_t=None, eval_time=False):
@@ -302,14 +300,13 @@ def run(self, duration, start_t=None, eval_time=False):
302300
refresh=True)
303301
if eval_time:
304302
t0 = time.time()
305-
hists, returns = self.step_func(times)
303+
hists = self.step_func(times)
306304
if eval_time:
307305
running_time = time.time() - t0
308306
if self.progress_bar:
309307
self._pbar.close()
310308

311309
# post-running
312-
hists.update(returns)
313310
times += self.dt
314311
if self.numpy_mon_after_run:
315312
times = np.asarray(times)

brainpy/train/online.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,7 @@ def _make_fit_func(self, shared_args: Dict):
234234

235235
monitor_func = self.build_monitors(self._mon_info[0], self._mon_info[1], shared_args)
236236

237-
def _step_func(all_inputs):
238-
t, i, x, ys = all_inputs
237+
def _step_func(t, i, x, ys):
239238
shared = DotDict(t=t, dt=self.dt, i=i)
240239

241240
# input step
@@ -262,8 +261,7 @@ def _step_func(all_inputs):
262261
if self.jit['fit']:
263262
dyn_vars = self.target.vars()
264263
dyn_vars.update(self.dyn_vars)
265-
f = bm.make_loop(_step_func, dyn_vars=dyn_vars.unique(), has_return=True)
266-
return lambda all_inputs: f(all_inputs)[1]
264+
return lambda all_inputs: bm.for_loop(_step_func, dyn_vars.unique(), all_inputs)
267265

268266
else:
269267
def run_func(all_inputs):
@@ -273,7 +271,7 @@ def run_func(all_inputs):
273271
for i in range(times.shape[0]):
274272
x = tree_map(lambda x: x[i], xs)
275273
y = tree_map(lambda x: x[i], ys)
276-
output, mon = _step_func((times[i], indices[i], x, y))
274+
output, mon = _step_func(times[i], indices[i], x, y)
277275
outputs.append(output)
278276
for key, value in mon.items():
279277
monitors[key].append(value)

0 commit comments

Comments
 (0)