Skip to content

Commit c043281

Browse files
committed
upgrade brainpy.train for new version of DynamicalSystem
1 parent 7c56adf commit c043281

File tree

6 files changed

+228
-326
lines changed

6 files changed

+228
-326
lines changed

brainpy/_src/runners.py

Lines changed: 81 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# -*- coding: utf-8 -*-
2-
2+
import functools
3+
import inspect
34
import time
45
import warnings
56
from collections.abc import Iterable
6-
from functools import partial
7-
from typing import Dict, Union, Sequence, Callable, Tuple, Optional
7+
from typing import Dict, Union, Sequence, Callable, Tuple, Optional, Any
88

99
import jax
1010
import jax.numpy as jnp
@@ -14,13 +14,12 @@
1414
from jax.tree_util import tree_map, tree_flatten
1515

1616
from brainpy import math as bm, tools
17-
from brainpy._src.dynsys import DynamicalSystem
1817
from brainpy._src.context import share
18+
from brainpy._src.deprecations import _input_deprecate_msg
19+
from brainpy._src.dynsys import DynamicalSystem
1920
from brainpy._src.running.runner import Runner
20-
from brainpy.check import serialize_kwargs
2121
from brainpy.errors import RunningError
22-
from brainpy.types import ArrayType, Output, Monitor
23-
22+
from brainpy.types import Output, Monitor
2423

2524
__all__ = [
2625
'DSRunner',
@@ -30,6 +29,16 @@
3029
SUPPORTED_INPUT_TYPE = ['fix', 'iter', 'func']
3130

3231

32+
def _call_fun_with_share(f, *args, **kwargs):
33+
try:
34+
sha = share.get_shargs()
35+
inspect.signature(f).bind(sha, *args, **kwargs)
36+
warnings.warn(_input_deprecate_msg, UserWarning)
37+
return f(sha, *args, **kwargs)
38+
except TypeError:
39+
return f(*args, **kwargs)
40+
41+
3342
def _is_brainpy_array(x):
3443
return isinstance(x, bm.Array)
3544

@@ -78,7 +87,6 @@ def check_and_format_inputs(host, inputs):
7887
# 2. get targets and attributes
7988
# ---------
8089
inputs_which_found_target = []
81-
inputs_not_found_target = []
8290

8391
# checking 1: absolute access
8492
# Check whether the input target node is accessible,
@@ -101,22 +109,6 @@ def check_and_format_inputs(host, inputs):
101109
f'specify variable of the target, but we got {key}.')
102110
inputs_which_found_target.append((real_target,) + tuple(one_input[1:]))
103111

104-
# checking 2: relative access
105-
# Check whether the input target node is accessible
106-
# and check whether the target node has the attribute
107-
# if len(inputs_not_found_target):
108-
# nodes = host.nodes(method='relative', level=-1, include_self=True)
109-
# for one_input in inputs_not_found_target:
110-
# splits = one_input[0].split('.')
111-
# target, key = '.'.join(splits[:-1]), splits[-1]
112-
# if target not in nodes:
113-
# raise RunningError(f'Input target "{target}" is not defined in {host}.')
114-
# real_target = nodes[target]
115-
# if not hasattr(real_target, key):
116-
# raise RunningError(f'Input target key "{key}" is not defined in {real_target}.')
117-
# real_target = getattr(real_target, key)
118-
# inputs_which_found_target.append((real_target,) + tuple(one_input[1:]))
119-
120112
# 3. format inputs
121113
# ---------
122114
formatted_inputs = []
@@ -257,7 +249,7 @@ class DSRunner(Runner):
257249
- A list of string with index specification. Like ``monitors=[('a', 1), ('b', [1,3,5]), 'c']``
258250
- A dict with the explicit monitor target, like: ``monitors={'a': model.spike, 'b': model.V}``
259251
- A dict with the index specification, like: ``monitors={'a': (model.spike, 0), 'b': (model.V, [1,2])}``
260-
- A dict with the callable function, like ``monitors={'a': lambda tdi: model.spike[:5]}``
252+
- A dict with the callable function, like ``monitors={'a': lambda: model.spike[:5]}``
261253
262254
.. versionchanged:: 2.3.1
263255
``fun_monitors`` are merged into ``monitors``.
@@ -266,8 +258,8 @@ class DSRunner(Runner):
266258
The dict ``key`` should be a string for the later retrieval by ``runner.mon[key]``.
267259
The dict ``value`` should be a callable function which receives two arguments: ``t`` and ``dt``.
268260
.. code-block::
269-
fun_monitors = {'spike': lambda tdi: model.spike[:10],
270-
'V10': lambda tdi: model.V[10]}
261+
fun_monitors = {'spike': lambda: model.spike[:10],
262+
'V10': lambda: model.V[10]}
271263
272264
.. deprecated:: 2.3.1
273265
Will be removed since version 2.4.0.
@@ -334,17 +326,16 @@ def __init__(
334326
if not isinstance(target, DynamicalSystem):
335327
raise RunningError(f'"target" must be an instance of {DynamicalSystem.__name__}, '
336328
f'but we got {type(target)}: {target}')
337-
super(DSRunner, self).__init__(target=target,
338-
monitors=monitors,
339-
fun_monitors=fun_monitors,
340-
jit=jit,
341-
progress_bar=progress_bar,
342-
dyn_vars=dyn_vars,
343-
numpy_mon_after_run=numpy_mon_after_run)
329+
super().__init__(target=target,
330+
monitors=monitors,
331+
fun_monitors=fun_monitors,
332+
jit=jit,
333+
progress_bar=progress_bar,
334+
dyn_vars=dyn_vars,
335+
numpy_mon_after_run=numpy_mon_after_run)
344336

345337
# t0 and i0
346338
self.i0 = 0
347-
self._t0 = t0
348339
self.t0 = t0
349340
if data_first_axis is None:
350341
data_first_axis = 'B' if isinstance(self.target.mode, bm.BatchingMode) else 'T'
@@ -369,7 +360,7 @@ def __init__(
369360
self._inputs = check_and_format_inputs(host=target, inputs=inputs)
370361

371362
# run function
372-
self._f_predict_compiled = dict()
363+
self._jit_step_func_predict = bm.jit(self._step_func_predict, static_argnames=['shared_args'])
373364

374365
# monitors
375366
self._memory_efficient = memory_efficient
@@ -388,15 +379,15 @@ def __repr__(self):
388379
def reset_state(self):
389380
"""Reset state of the ``DSRunner``."""
390381
self.i0 = 0
391-
self.t0 = self._t0
382+
self.t0 = self.t0
392383

393384
def predict(
394385
self,
395386
duration: float = None,
396-
inputs: Union[ArrayType, Sequence[ArrayType], Dict[str, ArrayType]] = None,
387+
inputs: Any = None,
397388
reset_state: bool = False,
398-
shared_args: Dict = None,
399389
eval_time: bool = False,
390+
shared_args: Dict = None,
400391

401392
# deprecated
402393
inputs_are_batching: bool = None,
@@ -431,10 +422,10 @@ def predict(
431422
Will be removed after version 2.4.0.
432423
reset_state: bool
433424
Whether reset the model states.
434-
shared_args: optional, dict
435-
The shared arguments across different layers.
436425
eval_time: bool
437426
Whether ro evaluate the running time.
427+
shared_args: optional, dict
428+
The shared arguments across different layers.
438429
439430
Returns
440431
-------
@@ -469,13 +460,7 @@ def predict(
469460
self.reset_state()
470461

471462
# shared arguments and inputs
472-
if shared_args is None:
473-
shared_args = dict()
474-
shared_args['fit'] = shared_args.get('fit', False)
475-
shared = tools.DotDict(i=np.arange(num_step, dtype=bm.int_))
476-
shared['t'] = shared['i'] * self.dt
477-
shared['i'] += self.i0
478-
shared['t'] += self.t0
463+
indices = np.arange(self.i0, self.i0 + num_step, dtype=bm.int_)
479464

480465
if isinstance(self.target.mode, bm.BatchingMode) and self.data_first_axis == 'B':
481466
inputs = tree_map(lambda x: jnp.moveaxis(x, 0, 1), inputs)
@@ -492,8 +477,11 @@ def predict(
492477
# running
493478
if eval_time:
494479
t0 = time.time()
495-
with jax.disable_jit(not self.jit['predict']):
496-
outputs, hists = self._predict(xs=(shared['t'], shared['i'], inputs), shared_args=shared_args)
480+
if inputs is None:
481+
inputs = tuple()
482+
if not isinstance(inputs, (tuple, list)):
483+
inputs = (inputs,)
484+
outputs, hists = self._predict(indices, *inputs, shared_args=shared_args)
497485
if eval_time:
498486
running_time = time.time() - t0
499487

@@ -503,17 +491,18 @@ def predict(
503491

504492
# post-running for monitors
505493
if self._memory_efficient:
506-
self.mon['ts'] = shared['t'] + self.dt
494+
self.mon['ts'] = indices * self.dt + self.t0
507495
for key in self.mon.var_names:
508496
self.mon[key] = np.asarray(self.mon[key])
509497
else:
510-
hists['ts'] = shared['t'] + self.dt
498+
hists['ts'] = indices * self.dt + self.t0
511499
if self.numpy_mon_after_run:
512500
hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.Array))
501+
else:
502+
hists['ts'] = bm.as_jax(hists['ts'])
513503
for key in hists.keys():
514504
self.mon[key] = hists[key]
515505
self.i0 += num_step
516-
self.t0 += (num_step * self.dt if duration is None else duration)
517506
return outputs if not eval_time else (running_time, outputs)
518507

519508
def run(self, *args, **kwargs) -> Union[Output, Tuple[float, Output]]:
@@ -526,17 +515,12 @@ def __call__(self, *args, **kwargs) -> Union[Output, Tuple[float, Output]]:
526515
"""
527516
return self.predict(*args, **kwargs)
528517

529-
def _predict(
530-
self,
531-
xs: Sequence,
532-
shared_args: Dict = None,
533-
) -> Union[Output, Monitor]:
518+
def _predict(self, indices, *xs, shared_args=None) -> Union[Output, Monitor]:
534519
"""Predict the output according to the inputs.
535520
536521
Parameters
537522
----------
538523
xs: sequence
539-
Must be a tuple/list of data, including `(times, indices, inputs)`.
540524
If `inputs` is not None, it should be a tensor with the shape of
541525
:math:`(num_time, ...)`.
542526
shared_args: optional, dict
@@ -547,18 +531,21 @@ def _predict(
547531
outputs, hists
548532
A tuple of pair of (outputs, hists).
549533
"""
550-
_predict_func = self._get_f_predict(shared_args)
551-
outs_and_mons = _predict_func(xs)
534+
if shared_args is None:
535+
shared_args = dict()
536+
shared_args = tools.DotDict(shared_args)
537+
538+
outs_and_mons = self._fun_predict(indices, *xs, shared_args=shared_args)
552539
if isinstance(self.target.mode, bm.BatchingMode) and self.data_first_axis == 'B':
553540
outs_and_mons = tree_map(lambda x: jnp.moveaxis(x, 0, 1) if x.ndim >= 2 else x,
554541
outs_and_mons)
555542
return outs_and_mons
556543

557-
def _step_func_monitor(self, shared):
544+
def _step_func_monitor(self):
558545
res = dict()
559546
for key, val in self._monitors.items():
560547
if callable(val):
561-
res[key] = val(shared)
548+
res[key] = _call_fun_with_share(val)
562549
else:
563550
(variable, idx) = val
564551
if idx is None:
@@ -567,21 +554,21 @@ def _step_func_monitor(self, shared):
567554
res[key] = variable[bm.as_jax(idx)]
568555
return res
569556

570-
def _step_func_input(self, shared):
557+
def _step_func_input(self):
571558
if self._fun_inputs is not None:
572-
self._fun_inputs(shared)
559+
self._fun_inputs(share.get_shargs())
573560
if callable(self._inputs):
574-
self._inputs(shared)
561+
_call_fun_with_share(self._inputs)
575562
else:
576563
for ops, values in self._inputs['fixed'].items():
577564
for var, data in values:
578565
_f_ops(ops, var, data)
579566
for ops, values in self._inputs['array'].items():
580567
for var, data in values:
581-
_f_ops(ops, var, data[shared['i']])
568+
_f_ops(ops, var, data[share['i']])
582569
for ops, values in self._inputs['functional'].items():
583570
for var, data in values:
584-
_f_ops(ops, var, data(shared))
571+
_f_ops(ops, var, _call_fun_with_share(data))
585572
for ops, values in self._inputs['iterated'].items():
586573
for var, data in values:
587574
_f_ops(ops, var, next(data))
@@ -628,25 +615,24 @@ def _step_mon_on_cpu(self, args, transforms):
628615
for key, val in args.items():
629616
self.mon[key].append(val)
630617

631-
def _step_func_predict(self, shared_args, t, i, x):
618+
def _step_func_predict(self, i, *x, shared_args=None):
632619
# input step
633-
shared = tools.DotDict(t=t, i=i, dt=self.dt)
634-
shared.update(shared_args)
635-
share.save(**shared)
636-
self._step_func_input(shared)
620+
if shared_args is not None:
621+
assert isinstance(shared_args, dict)
622+
share.save(**shared_args)
623+
share.save(t=self.t0 + i * self.dt, i=i, dt=self.dt)
624+
self._step_func_input()
637625

638626
# dynamics update step
639-
args = () if x is None else (x,)
640-
out = self.target(*args)
627+
out = self.target(*x)
641628

642629
# monitor step
643-
shared['t'] += self.dt
644-
mon = self._step_func_monitor(shared)
630+
mon = self._step_func_monitor()
645631

646632
# finally
647633
if self.progress_bar:
648634
id_tap(lambda *arg: self._pbar.update(), ())
649-
share.clear_shargs()
635+
# share.clear_shargs()
650636
self.target.clear_input()
651637

652638
if self._memory_efficient:
@@ -655,40 +641,23 @@ def _step_func_predict(self, shared_args, t, i, x):
655641
else:
656642
return out, mon
657643

658-
def _get_f_predict(self, shared_args: Dict = None):
659-
if shared_args is None:
660-
shared_args = dict()
661-
662-
shared_kwargs_str = serialize_kwargs(shared_args)
663-
if shared_kwargs_str not in self._f_predict_compiled:
664-
665-
if self._memory_efficient:
666-
_jit_step = bm.jit(partial(self._step_func_predict, shared_args))
667-
668-
def run_func(all_inputs):
669-
outs = None
670-
times, indices, xs = all_inputs
671-
for i in range(times.shape[0]):
672-
out, _ = _jit_step(times[i], indices[i], tree_map(lambda a: a[i], xs))
673-
if outs is None:
674-
outs = tree_map(lambda a: [], out)
675-
outs = tree_map(lambda a, o: o.append(a), out, outs)
676-
outs = tree_map(lambda a: bm.as_jax(a), outs)
677-
return outs, None
678-
644+
def _fun_predict(self, indices, *inputs, shared_args=None):
645+
if self._memory_efficient:
646+
if self.jit['predict']:
647+
run_fun = self._jit_step_func_predict
679648
else:
680-
step = partial(self._step_func_predict, shared_args)
649+
run_fun = self._step_func_predict
681650

682-
def run_func(all_inputs):
683-
return bm.for_loop(step, all_inputs, jit=self.jit['predict'])
684-
685-
self._f_predict_compiled[shared_kwargs_str] = run_func
686-
687-
return self._f_predict_compiled[shared_kwargs_str]
688-
689-
def __del__(self):
690-
if hasattr(self, '_f_predict_compiled'):
691-
for key in tuple(self._f_predict_compiled.keys()):
692-
self._f_predict_compiled.pop(key)
693-
super(DSRunner, self).__del__()
651+
outs = None
652+
for i in range(indices.shape[0]):
653+
out, _ = run_fun(indices[i], *tree_map(lambda a: a[i], inputs), shared_args=shared_args)
654+
if outs is None:
655+
outs = tree_map(lambda a: [], out)
656+
outs = tree_map(lambda a, o: o.append(a), out, outs)
657+
outs = tree_map(lambda a: bm.as_jax(a), outs)
658+
return outs, None
694659

660+
else:
661+
return bm.for_loop(functools.partial(self._step_func_predict, shared_args=shared_args),
662+
(indices, *inputs),
663+
jit=self.jit['predict'])

brainpy/_src/running/runner.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ def __init__(
127127

128128
# dynamical changed variables
129129
self._dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
130-
self.register_implicit_vars(self._dyn_vars)
131130

132131
# numpy mon after run
133132
self.numpy_mon_after_run = numpy_mon_after_run

0 commit comments

Comments
 (0)