Commit 6f73232

enable monitoring GPU models on CPU when setting DSRunner(..., memory_efficient=True)

1 parent 544e749

3 files changed, +61 −28 lines
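
For context, this is roughly how the new option is meant to be used (a minimal sketch assuming the public DSRunner API; the HH population and the monitored variable 'V' are illustrative, not part of this commit):

import brainpy as bp

net = bp.neurons.HH(100)                     # illustrative model; any DynamicalSystem
runner = bp.DSRunner(net,
                     monitors=['V'],
                     memory_efficient=True)  # stream monitor history to CPU (NumPy)
runner.run(1000.)                            # simulate 1000 ms
print(type(runner.mon['V']))                 # numpy.ndarray after the run
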

brainpy/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -113,7 +113,6 @@
 
 from . import running, testing
 from ._src.visualization import (visualize as visualize)
-from ._src.running.runner import (Runner as Runner)
 
 
 # Part 7: Deprecations #

brainpy/_src/dyn/base.py

Lines changed: 1 addition & 1 deletion
@@ -767,7 +767,7 @@ def check_post_attrs(self, *attrs):
       if not hasattr(self.post, attr):
         raise ValueError(f'{self} need "pre" neuron group has attribute "{attr}".')
 
-  def update(self, tdi, pre_spike=None):
+  def update(self, *args, **kwargs):
     """The function to specify the updating rule.
 
     Assume any dynamical system depends on the shared variables (`sha`),

brainpy/_src/dyn/runners.py

Lines changed: 60 additions & 26 deletions
@@ -17,8 +17,8 @@
 from brainpy._src.dyn.base import DynamicalSystem
 from brainpy._src.dyn.context import share
 from brainpy._src.running.runner import Runner
-from brainpy.check import is_float, serialize_kwargs
-from brainpy.errors import RunningError, NoLongerSupportError
+from brainpy.check import serialize_kwargs
+from brainpy.errors import RunningError
 from brainpy.types import ArrayType, Output, Monitor
 
 __all__ = [
@@ -319,6 +319,7 @@ def __init__(
       # jit
       jit: Union[bool, Dict[str, bool]] = True,
       dyn_vars: Optional[Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]]] = None,
+      memory_efficient: bool = False,
 
       # extra info
       dt: Optional[float] = None,
@@ -342,10 +343,9 @@
                      numpy_mon_after_run=numpy_mon_after_run)
 
     # t0 and i0
-    is_float(t0, 't0', allow_none=False, allow_int=True)
+    self.i0 = 0
     self._t0 = t0
-    self.i0 = bm.Variable(jnp.asarray(1, dtype=bm.int_))
-    self.t0 = bm.Variable(jnp.asarray(t0, dtype=bm.float_))
+    self.t0 = t0
     if data_first_axis is None:
       data_first_axis = 'B' if isinstance(self.target.mode, bm.BatchingMode) else 'T'
     assert data_first_axis in ['B', 'T']
@@ -371,6 +371,11 @@
     # run function
     self._f_predict_compiled = dict()
 
+    # monitors
+    self._memory_efficient = memory_efficient
+    if memory_efficient and not numpy_mon_after_run:
+      raise ValueError('When setting "gpu_memory_efficient=True", "numpy_mon_after_run" can not be False.')
+
   def __repr__(self):
     name = self.__class__.__name__
     indent = " " * len(name) + ' '
@@ -382,8 +387,8 @@
 
   def reset_state(self):
     """Reset state of the ``DSRunner``."""
-    self.i0.value = jnp.zeros_like(self.i0.value)
-    self.t0.value = jnp.ones_like(self.t0.value) * self._t0
+    self.i0 = 0
+    self.t0 = self._t0
 
   def predict(
       self,
@@ -438,11 +443,12 @@
     """
 
     if inputs_are_batching is not None:
-      raise NoLongerSupportError(
+      raise warnings.warn(
         f'''
 `inputs_are_batching` is no longer supported.
 The target mode of {self.target.mode} has already indicated the input should be batching.
-'''
+''',
+        UserWarning
       )
     if duration is None:
       if inputs is None:
@@ -466,7 +472,7 @@
     if shared_args is None:
       shared_args = dict()
     shared_args['fit'] = shared_args.get('fit', False)
-    shared = tools.DotDict(i=jnp.arange(num_step, dtype=bm.int_))
+    shared = tools.DotDict(i=np.arange(num_step, dtype=bm.int_))
     shared['t'] = shared['i'] * self.dt
     shared['i'] += self.i0
     shared['t'] += self.t0
@@ -486,7 +492,8 @@
     # running
     if eval_time:
       t0 = time.time()
-    outputs, hists = self._predict(xs=(shared['t'], shared['i'], inputs), shared_args=shared_args)
+    with jax.disable_jit(not self.jit['predict']):
+      outputs, hists = self._predict(xs=(shared['t'], shared['i'], inputs), shared_args=shared_args)
     if eval_time:
       running_time = time.time() - t0
 
@@ -495,11 +502,16 @@
       self._pbar.close()
 
     # post-running for monitors
-    hists['ts'] = shared['t'] + self.dt
-    if self.numpy_mon_after_run:
-      hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.Array))
-    for key in hists.keys():
-      self.mon[key] = hists[key]
+    if self._memory_efficient:
+      self.mon['ts'] = shared['t'] + self.dt
+      for key in self.mon.var_names:
+        self.mon[key] = np.asarray(self.mon[key])
+    else:
+      hists['ts'] = shared['t'] + self.dt
+      if self.numpy_mon_after_run:
+        hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.Array))
+      for key in hists.keys():
+        self.mon[key] = hists[key]
     self.i0 += num_step
     self.t0 += (num_step * self.dt if duration is None else duration)
     return outputs if not eval_time else (running_time, outputs)
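
On the non-memory-efficient path the device-side histories are converted in one pass: tree_map applies the conversion to every leaf of the monitor pytree. A tiny sketch of that pattern in plain JAX (the monitor names are made up):

import numpy as np
import jax.numpy as jnp
from jax.tree_util import tree_map

hists = {'V': jnp.zeros((100, 10)),
         'spike': jnp.zeros((100, 10), dtype=bool)}
hists = tree_map(np.asarray, hists)   # device arrays -> NumPy, leaf by leaf
print({k: type(v).__name__ for k, v in hists.items()})
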
@@ -609,10 +621,13 @@ def _get_input_time_step(self, duration=None, xs=None) -> int:
         raise ValueError(f'Number of time step is different across arrays in '
                          f'the provided "xs". We got {set(num_steps)}.')
       return num_steps[0]
-
     else:
       raise ValueError
 
+  def _step_mon_on_cpu(self, args, transforms):
+    for key, val in args.items():
+      self.mon[key].append(val)
+
   def _step_func_predict(self, shared_args, t, i, x):
     # input step
     shared = tools.DotDict(t=t, i=i, dt=self.dt)
@@ -633,7 +648,12 @@ def _step_func_predict(self, shared_args, t, i, x):
     if self.progress_bar:
       id_tap(lambda *arg: self._pbar.update(), ())
     share.clear_shargs()
-    return out, mon
+
+    if self._memory_efficient:
+      id_tap(self._step_mon_on_cpu, mon)
+      return out, None
+    else:
+      return out, mon
 
   def _get_f_predict(self, shared_args: Dict = None):
     if shared_args is None:
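
The CPU transfer rests on jax.experimental.host_callback.id_tap: inside a compiled step it ships a value out to a Python callback on the host, and the callback receives (value, transforms), matching the _step_mon_on_cpu signature above. A self-contained sketch of the mechanism (the names here are illustrative):

import jax
import jax.numpy as jnp
from jax.experimental.host_callback import id_tap

host_buffer = []

def tap_fn(value, transforms):
  # Runs on the host (CPU), outside the compiled computation.
  host_buffer.append(value)

@jax.jit
def step(x):
  id_tap(tap_fn, x)   # send x to the host; the device keeps running
  return x * 2.

for i in range(3):
  step(jnp.float32(i))
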
@@ -646,16 +666,30 @@
     dyn_vars.update(self.vars(level=0))
     dyn_vars = dyn_vars.unique()
 
-    def run_func(all_inputs):
-      return bm.for_loop(partial(self._step_func_predict, shared_args),
-                         all_inputs,
-                         dyn_vars=dyn_vars,
-                         jit=self.jit['predict'])
+    if self._memory_efficient:
+      _jit_step = bm.jit(partial(self._step_func_predict, shared_args), dyn_vars=dyn_vars)
+
+      def run_func(all_inputs):
+        outs = None
+        times, indices, xs = all_inputs
+        for i in range(times.shape[0]):
+          out, _ = _jit_step(times[i], indices[i], tree_map(lambda a: a[i], xs))
+          if outs is None:
+            outs = tree_map(lambda a: [], out)
+          outs = tree_map(lambda a, o: o.append(a), out, outs)
+        outs = tree_map(lambda a: bm.as_jax(a), outs)
+        return outs, None
 
-    if self.jit['predict']:
-      self._f_predict_compiled[shared_kwargs_str] = bm.jit(run_func, dyn_vars=dyn_vars)
     else:
-      self._f_predict_compiled[shared_kwargs_str] = run_func
+      @bm.jit(dyn_vars=dyn_vars)
+      def run_func(all_inputs):
+        return bm.for_loop(partial(self._step_func_predict, shared_args),
+                           all_inputs,
+                           dyn_vars=dyn_vars,
+                           jit=self.jit['predict'])
+
+    self._f_predict_compiled[shared_kwargs_str] = run_func
+
     return self._f_predict_compiled[shared_kwargs_str]
 
   def __del__(self):
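
The two run_func variants trade speed for device memory: the original path compiles the whole time loop, so every monitored step stays on the device until the run ends, while the memory-efficient path compiles a single step and iterates in Python, moving each step's output to the host as it arrives. A toy version of the same trade-off in plain JAX (names are illustrative):

import jax
import jax.numpy as jnp

def step(carry, t):
  carry = carry + t
  return carry, carry                      # (new state, per-step output)

ts = jnp.arange(5.)

# Scanned loop: the full output history lives on the device.
_, hist_device = jax.lax.scan(step, 0., ts)

# Per-step loop: compile one step, stream outputs to host memory.
jit_step = jax.jit(step)
carry, hist_host = 0., []
for t in ts:
  carry, out = jit_step(carry, t)
  hist_host.append(jax.device_get(out))    # each step's output lands on CPU
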
