Skip to content

Commit 00c790a

Browse files
authored
Merge pull request #355 from chaoming0625/master
Enable memory-efficient ``DSRunner``
2 parents 99f22b6 + 6f73232 commit 00c790a

File tree

7 files changed

+252
-190
lines changed

7 files changed

+252
-190
lines changed

brainpy/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@
8181
from brainpy._src.dyn.transform import (LoopOverTime as LoopOverTime,)
8282
# DynamicalSystem base classes
8383
from brainpy._src.dyn.base import (DynamicalSystemNS as DynamicalSystemNS,
84-
NeuGroupNS as NeuGroupNS)
84+
NeuGroupNS as NeuGroupNS,
85+
TwoEndConnNS as TwoEndConnNS,
86+
)
8587
from brainpy._src.dyn.synapses_v2.base import (SynOutNS as SynOutNS,
8688
SynSTPNS as SynSTPNS,
8789
SynConnNS as SynConnNS, )
@@ -111,7 +113,6 @@
111113

112114
from . import running, testing
113115
from ._src.visualization import (visualize as visualize)
114-
from ._src.running.runner import (Runner as Runner)
115116

116117

117118
# Part 7: Deprecations #

brainpy/_src/dyn/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@ def check_post_attrs(self, *attrs):
767767
if not hasattr(self.post, attr):
768768
raise ValueError(f'{self} need "pre" neuron group has attribute "{attr}".')
769769

770-
def update(self, tdi, pre_spike=None):
770+
def update(self, *args, **kwargs):
771771
"""The function to specify the updating rule.
772772
773773
Assume any dynamical system depends on the shared variables (`sha`),
@@ -1024,6 +1024,11 @@ def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
10241024
return post_vs
10251025

10261026

1027+
class TwoEndConnNS(TwoEndConn):
1028+
"""Two-end connection without passing shared arguments."""
1029+
_pass_shared_args = False
1030+
1031+
10271032
class CondNeuGroup(NeuGroup, Container):
10281033
r"""Base class to model conductance-based neuron group.
10291034

brainpy/_src/dyn/runners.py

Lines changed: 60 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from brainpy._src.dyn.base import DynamicalSystem
1818
from brainpy._src.dyn.context import share
1919
from brainpy._src.running.runner import Runner
20-
from brainpy.check import is_float, serialize_kwargs
21-
from brainpy.errors import RunningError, NoLongerSupportError
20+
from brainpy.check import serialize_kwargs
21+
from brainpy.errors import RunningError
2222
from brainpy.types import ArrayType, Output, Monitor
2323

2424
__all__ = [
@@ -319,6 +319,7 @@ def __init__(
319319
# jit
320320
jit: Union[bool, Dict[str, bool]] = True,
321321
dyn_vars: Optional[Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]]] = None,
322+
memory_efficient: bool = False,
322323

323324
# extra info
324325
dt: Optional[float] = None,
@@ -342,10 +343,9 @@ def __init__(
342343
numpy_mon_after_run=numpy_mon_after_run)
343344

344345
# t0 and i0
345-
is_float(t0, 't0', allow_none=False, allow_int=True)
346+
self.i0 = 0
346347
self._t0 = t0
347-
self.i0 = bm.Variable(jnp.asarray(1, dtype=bm.int_))
348-
self.t0 = bm.Variable(jnp.asarray(t0, dtype=bm.float_))
348+
self.t0 = t0
349349
if data_first_axis is None:
350350
data_first_axis = 'B' if isinstance(self.target.mode, bm.BatchingMode) else 'T'
351351
assert data_first_axis in ['B', 'T']
@@ -371,6 +371,11 @@ def __init__(
371371
# run function
372372
self._f_predict_compiled = dict()
373373

374+
# monitors
375+
self._memory_efficient = memory_efficient
376+
if memory_efficient and not numpy_mon_after_run:
377+
raise ValueError('When setting "gpu_memory_efficient=True", "numpy_mon_after_run" can not be False.')
378+
374379
def __repr__(self):
375380
name = self.__class__.__name__
376381
indent = " " * len(name) + ' '
@@ -382,8 +387,8 @@ def __repr__(self):
382387

383388
def reset_state(self):
384389
"""Reset state of the ``DSRunner``."""
385-
self.i0.value = jnp.zeros_like(self.i0.value)
386-
self.t0.value = jnp.ones_like(self.t0.value) * self._t0
390+
self.i0 = 0
391+
self.t0 = self._t0
387392

388393
def predict(
389394
self,
@@ -438,11 +443,12 @@ def predict(
438443
"""
439444

440445
if inputs_are_batching is not None:
441-
raise NoLongerSupportError(
446+
raise warnings.warn(
442447
f'''
443448
`inputs_are_batching` is no longer supported.
444449
The target mode of {self.target.mode} has already indicated the input should be batching.
445-
'''
450+
''',
451+
UserWarning
446452
)
447453
if duration is None:
448454
if inputs is None:
@@ -466,7 +472,7 @@ def predict(
466472
if shared_args is None:
467473
shared_args = dict()
468474
shared_args['fit'] = shared_args.get('fit', False)
469-
shared = tools.DotDict(i=jnp.arange(num_step, dtype=bm.int_))
475+
shared = tools.DotDict(i=np.arange(num_step, dtype=bm.int_))
470476
shared['t'] = shared['i'] * self.dt
471477
shared['i'] += self.i0
472478
shared['t'] += self.t0
@@ -486,7 +492,8 @@ def predict(
486492
# running
487493
if eval_time:
488494
t0 = time.time()
489-
outputs, hists = self._predict(xs=(shared['t'], shared['i'], inputs), shared_args=shared_args)
495+
with jax.disable_jit(not self.jit['predict']):
496+
outputs, hists = self._predict(xs=(shared['t'], shared['i'], inputs), shared_args=shared_args)
490497
if eval_time:
491498
running_time = time.time() - t0
492499

@@ -495,11 +502,16 @@ def predict(
495502
self._pbar.close()
496503

497504
# post-running for monitors
498-
hists['ts'] = shared['t'] + self.dt
499-
if self.numpy_mon_after_run:
500-
hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.Array))
501-
for key in hists.keys():
502-
self.mon[key] = hists[key]
505+
if self._memory_efficient:
506+
self.mon['ts'] = shared['t'] + self.dt
507+
for key in self.mon.var_names:
508+
self.mon[key] = np.asarray(self.mon[key])
509+
else:
510+
hists['ts'] = shared['t'] + self.dt
511+
if self.numpy_mon_after_run:
512+
hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.Array))
513+
for key in hists.keys():
514+
self.mon[key] = hists[key]
503515
self.i0 += num_step
504516
self.t0 += (num_step * self.dt if duration is None else duration)
505517
return outputs if not eval_time else (running_time, outputs)
@@ -609,10 +621,13 @@ def _get_input_time_step(self, duration=None, xs=None) -> int:
609621
raise ValueError(f'Number of time step is different across arrays in '
610622
f'the provided "xs". We got {set(num_steps)}.')
611623
return num_steps[0]
612-
613624
else:
614625
raise ValueError
615626

627+
def _step_mon_on_cpu(self, args, transforms):
628+
for key, val in args.items():
629+
self.mon[key].append(val)
630+
616631
def _step_func_predict(self, shared_args, t, i, x):
617632
# input step
618633
shared = tools.DotDict(t=t, i=i, dt=self.dt)
@@ -633,7 +648,12 @@ def _step_func_predict(self, shared_args, t, i, x):
633648
if self.progress_bar:
634649
id_tap(lambda *arg: self._pbar.update(), ())
635650
share.clear_shargs()
636-
return out, mon
651+
652+
if self._memory_efficient:
653+
id_tap(self._step_mon_on_cpu, mon)
654+
return out, None
655+
else:
656+
return out, mon
637657

638658
def _get_f_predict(self, shared_args: Dict = None):
639659
if shared_args is None:
@@ -646,16 +666,30 @@ def _get_f_predict(self, shared_args: Dict = None):
646666
dyn_vars.update(self.vars(level=0))
647667
dyn_vars = dyn_vars.unique()
648668

649-
def run_func(all_inputs):
650-
return bm.for_loop(partial(self._step_func_predict, shared_args),
651-
all_inputs,
652-
dyn_vars=dyn_vars,
653-
jit=self.jit['predict'])
669+
if self._memory_efficient:
670+
_jit_step = bm.jit(partial(self._step_func_predict, shared_args), dyn_vars=dyn_vars)
671+
672+
def run_func(all_inputs):
673+
outs = None
674+
times, indices, xs = all_inputs
675+
for i in range(times.shape[0]):
676+
out, _ = _jit_step(times[i], indices[i], tree_map(lambda a: a[i], xs))
677+
if outs is None:
678+
outs = tree_map(lambda a: [], out)
679+
outs = tree_map(lambda a, o: o.append(a), out, outs)
680+
outs = tree_map(lambda a: bm.as_jax(a), outs)
681+
return outs, None
654682

655-
if self.jit['predict']:
656-
self._f_predict_compiled[shared_kwargs_str] = bm.jit(run_func, dyn_vars=dyn_vars)
657683
else:
658-
self._f_predict_compiled[shared_kwargs_str] = run_func
684+
@bm.jit(dyn_vars=dyn_vars)
685+
def run_func(all_inputs):
686+
return bm.for_loop(partial(self._step_func_predict, shared_args),
687+
all_inputs,
688+
dyn_vars=dyn_vars,
689+
jit=self.jit['predict'])
690+
691+
self._f_predict_compiled[shared_kwargs_str] = run_func
692+
659693
return self._f_predict_compiled[shared_kwargs_str]
660694

661695
def __del__(self):

brainpy/_src/initialize/generic.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def variable(
168168
if isinstance(batch_size_or_mode, bm.NonBatchingMode):
169169
return bm.Variable(init(size))
170170
elif isinstance(batch_size_or_mode, bm.BatchingMode):
171-
new_shape = size[:batch_axis] + (1,) + size[batch_axis:]
171+
new_shape = size[:batch_axis] + (batch_size_or_mode.batch_size,) + size[batch_axis:]
172172
return bm.Variable(init(new_shape), batch_axis=batch_axis)
173173
elif batch_size_or_mode in (None, False):
174174
return bm.Variable(init(size))
@@ -185,7 +185,10 @@ def variable(
185185
if isinstance(batch_size_or_mode, bm.NonBatchingMode):
186186
return bm.Variable(init)
187187
elif isinstance(batch_size_or_mode, bm.BatchingMode):
188-
return bm.Variable(bm.expand_dims(init, axis=batch_axis), batch_axis=batch_axis)
188+
return bm.Variable(bm.repeat(bm.expand_dims(init, axis=batch_axis),
189+
batch_size_or_mode.batch_size,
190+
axis=batch_axis),
191+
batch_axis=batch_axis)
189192
elif batch_size_or_mode in (None, False):
190193
return bm.Variable(init)
191194
elif isinstance(batch_size_or_mode, int):

brainpy/_src/math/environment.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -490,20 +490,23 @@ class training_environment(environment):
490490
491491
"""
492492

493-
def __init__(self,
494-
dt: float = None,
495-
x64: bool = None,
496-
complex_: type = None,
497-
float_: type = None,
498-
int_: type = None,
499-
bool_: type = None):
493+
def __init__(
494+
self,
495+
dt: float = None,
496+
x64: bool = None,
497+
complex_: type = None,
498+
float_: type = None,
499+
int_: type = None,
500+
bool_: type = None,
501+
batch_size: int = 1,
502+
):
500503
super().__init__(dt=dt,
501504
x64=x64,
502505
complex_=complex_,
503506
float_=float_,
504507
int_=int_,
505508
bool_=bool_,
506-
mode=modes.TrainingMode())
509+
mode=modes.TrainingMode(batch_size))
507510

508511

509512
class batching_environment(environment):
@@ -519,20 +522,23 @@ class batching_environment(environment):
519522
520523
"""
521524

522-
def __init__(self,
523-
dt: float = None,
524-
x64: bool = None,
525-
complex_: type = None,
526-
float_: type = None,
527-
int_: type = None,
528-
bool_: type = None):
525+
def __init__(
526+
self,
527+
dt: float = None,
528+
x64: bool = None,
529+
complex_: type = None,
530+
float_: type = None,
531+
int_: type = None,
532+
bool_: type = None,
533+
batch_size: int = 1,
534+
):
529535
super().__init__(dt=dt,
530536
x64=x64,
531537
complex_=complex_,
532538
float_=float_,
533539
int_=int_,
534540
bool_=bool_,
535-
mode=modes.BatchingMode())
541+
mode=modes.BatchingMode(batch_size))
536542

537543

538544
def enable_x64():

brainpy/_src/math/modes.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __eq__(self, other: 'Mode'):
2626
def is_a(self, mode: type):
2727
assert isinstance(mode, type)
2828
return self.__class__ == mode
29-
29+
3030
def is_parent_of(self, *modes):
3131
cls = self.__class__
3232
for smode in modes:
@@ -58,7 +58,12 @@ class BatchingMode(Mode):
5858
5959
:py:class:`~.NonBatchingMode` is usually used in models of model trainings.
6060
"""
61-
pass
61+
62+
def __init__(self, batch_size: int = 1):
63+
self.batch_size = batch_size
64+
65+
def __repr__(self):
66+
return f'{self.__class__.__name__}(batch_size={self.batch_size})'
6267

6368

6469
class TrainingMode(BatchingMode):
@@ -74,4 +79,3 @@ class TrainingMode(BatchingMode):
7479

7580
training_mode = TrainingMode()
7681
'''Default instance of the training computation mode.'''
77-

0 commit comments

Comments
 (0)