11# -*- coding: utf-8 -*-
2-
2+ import functools
3+ import inspect
34import time
45import warnings
56from collections .abc import Iterable
6- from functools import partial
7- from typing import Dict , Union , Sequence , Callable , Tuple , Optional
7+ from typing import Dict , Union , Sequence , Callable , Tuple , Optional , Any
88
99import jax
1010import jax .numpy as jnp
1414from jax .tree_util import tree_map , tree_flatten
1515
1616from brainpy import math as bm , tools
17- from brainpy ._src .dynsys import DynamicalSystem
1817from brainpy ._src .context import share
18+ from brainpy ._src .deprecations import _input_deprecate_msg
19+ from brainpy ._src .dynsys import DynamicalSystem
1920from brainpy ._src .running .runner import Runner
20- from brainpy .check import serialize_kwargs
2121from brainpy .errors import RunningError
22- from brainpy .types import ArrayType , Output , Monitor
23-
22+ from brainpy .types import Output , Monitor
2423
2524__all__ = [
2625 'DSRunner' ,
3029SUPPORTED_INPUT_TYPE = ['fix' , 'iter' , 'func' ]
3130
3231
32+ def _call_fun_with_share (f , * args , ** kwargs ):
33+ try :
34+ sha = share .get_shargs ()
35+ inspect .signature (f ).bind (sha , * args , ** kwargs )
36+ warnings .warn (_input_deprecate_msg , UserWarning )
37+ return f (sha , * args , ** kwargs )
38+ except TypeError :
39+ return f (* args , ** kwargs )
40+
41+
3342def _is_brainpy_array (x ):
3443 return isinstance (x , bm .Array )
3544
@@ -78,7 +87,6 @@ def check_and_format_inputs(host, inputs):
7887 # 2. get targets and attributes
7988 # ---------
8089 inputs_which_found_target = []
81- inputs_not_found_target = []
8290
8391 # checking 1: absolute access
8492 # Check whether the input target node is accessible,
@@ -101,22 +109,6 @@ def check_and_format_inputs(host, inputs):
101109 f'specify variable of the target, but we got { key } .' )
102110 inputs_which_found_target .append ((real_target ,) + tuple (one_input [1 :]))
103111
104- # checking 2: relative access
105- # Check whether the input target node is accessible
106- # and check whether the target node has the attribute
107- # if len(inputs_not_found_target):
108- # nodes = host.nodes(method='relative', level=-1, include_self=True)
109- # for one_input in inputs_not_found_target:
110- # splits = one_input[0].split('.')
111- # target, key = '.'.join(splits[:-1]), splits[-1]
112- # if target not in nodes:
113- # raise RunningError(f'Input target "{target}" is not defined in {host}.')
114- # real_target = nodes[target]
115- # if not hasattr(real_target, key):
116- # raise RunningError(f'Input target key "{key}" is not defined in {real_target}.')
117- # real_target = getattr(real_target, key)
118- # inputs_which_found_target.append((real_target,) + tuple(one_input[1:]))
119-
120112 # 3. format inputs
121113 # ---------
122114 formatted_inputs = []
@@ -257,7 +249,7 @@ class DSRunner(Runner):
257249 - A list of string with index specification. Like ``monitors=[('a', 1), ('b', [1,3,5]), 'c']``
258250 - A dict with the explicit monitor target, like: ``monitors={'a': model.spike, 'b': model.V}``
259251 - A dict with the index specification, like: ``monitors={'a': (model.spike, 0), 'b': (model.V, [1,2])}``
260- - A dict with the callable function, like ``monitors={'a': lambda tdi : model.spike[:5]}``
252+ - A dict with the callable function, like ``monitors={'a': lambda: model.spike[:5]}``
261253
262254 .. versionchanged:: 2.3.1
263255 ``fun_monitors`` are merged into ``monitors``.
@@ -266,8 +258,8 @@ class DSRunner(Runner):
266258 The dict ``key`` should be a string for the later retrieval by ``runner.mon[key]``.
267259 The dict ``value`` should be a callable function which receives two arguments: ``t`` and ``dt``.
268260 .. code-block::
269- fun_monitors = {'spike': lambda tdi : model.spike[:10],
270- 'V10': lambda tdi : model.V[10]}
261+ fun_monitors = {'spike': lambda: model.spike[:10],
262+ 'V10': lambda: model.V[10]}
271263
272264 .. deprecated:: 2.3.1
273265 Will be removed since version 2.4.0.
@@ -334,17 +326,16 @@ def __init__(
334326 if not isinstance (target , DynamicalSystem ):
335327 raise RunningError (f'"target" must be an instance of { DynamicalSystem .__name__ } , '
336328 f'but we got { type (target )} : { target } ' )
337- super (DSRunner , self ).__init__ (target = target ,
338- monitors = monitors ,
339- fun_monitors = fun_monitors ,
340- jit = jit ,
341- progress_bar = progress_bar ,
342- dyn_vars = dyn_vars ,
343- numpy_mon_after_run = numpy_mon_after_run )
329+ super ().__init__ (target = target ,
330+ monitors = monitors ,
331+ fun_monitors = fun_monitors ,
332+ jit = jit ,
333+ progress_bar = progress_bar ,
334+ dyn_vars = dyn_vars ,
335+ numpy_mon_after_run = numpy_mon_after_run )
344336
345337 # t0 and i0
346338 self .i0 = 0
347- self ._t0 = t0
348339 self .t0 = t0
349340 if data_first_axis is None :
350341 data_first_axis = 'B' if isinstance (self .target .mode , bm .BatchingMode ) else 'T'
@@ -369,7 +360,7 @@ def __init__(
369360 self ._inputs = check_and_format_inputs (host = target , inputs = inputs )
370361
371362 # run function
372- self ._f_predict_compiled = dict ( )
363+ self ._jit_step_func_predict = bm . jit ( self . _step_func_predict , static_argnames = [ 'shared_args' ] )
373364
374365 # monitors
375366 self ._memory_efficient = memory_efficient
@@ -388,15 +379,15 @@ def __repr__(self):
388379 def reset_state (self ):
389380 """Reset state of the ``DSRunner``."""
390381 self .i0 = 0
391- self .t0 = self ._t0
382+ self .t0 = self .t0
392383
393384 def predict (
394385 self ,
395386 duration : float = None ,
396- inputs : Union [ ArrayType , Sequence [ ArrayType ], Dict [ str , ArrayType ]] = None ,
387+ inputs : Any = None ,
397388 reset_state : bool = False ,
398- shared_args : Dict = None ,
399389 eval_time : bool = False ,
390+ shared_args : Dict = None ,
400391
401392 # deprecated
402393 inputs_are_batching : bool = None ,
@@ -431,10 +422,10 @@ def predict(
431422 Will be removed after version 2.4.0.
432423 reset_state: bool
433424 Whether reset the model states.
434- shared_args: optional, dict
435- The shared arguments across different layers.
436425 eval_time: bool
437426 Whether ro evaluate the running time.
427+ shared_args: optional, dict
428+ The shared arguments across different layers.
438429
439430 Returns
440431 -------
@@ -469,13 +460,7 @@ def predict(
469460 self .reset_state ()
470461
471462 # shared arguments and inputs
472- if shared_args is None :
473- shared_args = dict ()
474- shared_args ['fit' ] = shared_args .get ('fit' , False )
475- shared = tools .DotDict (i = np .arange (num_step , dtype = bm .int_ ))
476- shared ['t' ] = shared ['i' ] * self .dt
477- shared ['i' ] += self .i0
478- shared ['t' ] += self .t0
463+ indices = np .arange (self .i0 , self .i0 + num_step , dtype = bm .int_ )
479464
480465 if isinstance (self .target .mode , bm .BatchingMode ) and self .data_first_axis == 'B' :
481466 inputs = tree_map (lambda x : jnp .moveaxis (x , 0 , 1 ), inputs )
@@ -492,8 +477,11 @@ def predict(
492477 # running
493478 if eval_time :
494479 t0 = time .time ()
495- with jax .disable_jit (not self .jit ['predict' ]):
496- outputs , hists = self ._predict (xs = (shared ['t' ], shared ['i' ], inputs ), shared_args = shared_args )
480+ if inputs is None :
481+ inputs = tuple ()
482+ if not isinstance (inputs , (tuple , list )):
483+ inputs = (inputs ,)
484+ outputs , hists = self ._predict (indices , * inputs , shared_args = shared_args )
497485 if eval_time :
498486 running_time = time .time () - t0
499487
@@ -503,17 +491,18 @@ def predict(
503491
504492 # post-running for monitors
505493 if self ._memory_efficient :
506- self .mon ['ts' ] = shared [ 't' ] + self .dt
494+ self .mon ['ts' ] = indices * self .dt + self . t0
507495 for key in self .mon .var_names :
508496 self .mon [key ] = np .asarray (self .mon [key ])
509497 else :
510- hists ['ts' ] = shared [ 't' ] + self .dt
498+ hists ['ts' ] = indices * self .dt + self . t0
511499 if self .numpy_mon_after_run :
512500 hists = tree_map (lambda a : np .asarray (a ), hists , is_leaf = lambda a : isinstance (a , bm .Array ))
501+ else :
502+ hists ['ts' ] = bm .as_jax (hists ['ts' ])
513503 for key in hists .keys ():
514504 self .mon [key ] = hists [key ]
515505 self .i0 += num_step
516- self .t0 += (num_step * self .dt if duration is None else duration )
517506 return outputs if not eval_time else (running_time , outputs )
518507
519508 def run (self , * args , ** kwargs ) -> Union [Output , Tuple [float , Output ]]:
@@ -526,17 +515,12 @@ def __call__(self, *args, **kwargs) -> Union[Output, Tuple[float, Output]]:
526515 """
527516 return self .predict (* args , ** kwargs )
528517
529- def _predict (
530- self ,
531- xs : Sequence ,
532- shared_args : Dict = None ,
533- ) -> Union [Output , Monitor ]:
518+ def _predict (self , indices , * xs , shared_args = None ) -> Union [Output , Monitor ]:
534519 """Predict the output according to the inputs.
535520
536521 Parameters
537522 ----------
538523 xs: sequence
539- Must be a tuple/list of data, including `(times, indices, inputs)`.
540524 If `inputs` is not None, it should be a tensor with the shape of
541525 :math:`(num_time, ...)`.
542526 shared_args: optional, dict
@@ -547,18 +531,21 @@ def _predict(
547531 outputs, hists
548532 A tuple of pair of (outputs, hists).
549533 """
550- _predict_func = self ._get_f_predict (shared_args )
551- outs_and_mons = _predict_func (xs )
534+ if shared_args is None :
535+ shared_args = dict ()
536+ shared_args = tools .DotDict (shared_args )
537+
538+ outs_and_mons = self ._fun_predict (indices , * xs , shared_args = shared_args )
552539 if isinstance (self .target .mode , bm .BatchingMode ) and self .data_first_axis == 'B' :
553540 outs_and_mons = tree_map (lambda x : jnp .moveaxis (x , 0 , 1 ) if x .ndim >= 2 else x ,
554541 outs_and_mons )
555542 return outs_and_mons
556543
557- def _step_func_monitor (self , shared ):
544+ def _step_func_monitor (self ):
558545 res = dict ()
559546 for key , val in self ._monitors .items ():
560547 if callable (val ):
561- res [key ] = val ( shared )
548+ res [key ] = _call_fun_with_share ( val )
562549 else :
563550 (variable , idx ) = val
564551 if idx is None :
@@ -567,21 +554,21 @@ def _step_func_monitor(self, shared):
567554 res [key ] = variable [bm .as_jax (idx )]
568555 return res
569556
570- def _step_func_input (self , shared ):
557+ def _step_func_input (self ):
571558 if self ._fun_inputs is not None :
572- self ._fun_inputs (shared )
559+ self ._fun_inputs (share . get_shargs () )
573560 if callable (self ._inputs ):
574- self ._inputs ( shared )
561+ _call_fun_with_share ( self ._inputs )
575562 else :
576563 for ops , values in self ._inputs ['fixed' ].items ():
577564 for var , data in values :
578565 _f_ops (ops , var , data )
579566 for ops , values in self ._inputs ['array' ].items ():
580567 for var , data in values :
581- _f_ops (ops , var , data [shared ['i' ]])
568+ _f_ops (ops , var , data [share ['i' ]])
582569 for ops , values in self ._inputs ['functional' ].items ():
583570 for var , data in values :
584- _f_ops (ops , var , data ( shared ))
571+ _f_ops (ops , var , _call_fun_with_share ( data ))
585572 for ops , values in self ._inputs ['iterated' ].items ():
586573 for var , data in values :
587574 _f_ops (ops , var , next (data ))
@@ -628,25 +615,24 @@ def _step_mon_on_cpu(self, args, transforms):
628615 for key , val in args .items ():
629616 self .mon [key ].append (val )
630617
631- def _step_func_predict (self , shared_args , t , i , x ):
618+ def _step_func_predict (self , i , * x , shared_args = None ):
632619 # input step
633- shared = tools .DotDict (t = t , i = i , dt = self .dt )
634- shared .update (shared_args )
635- share .save (** shared )
636- self ._step_func_input (shared )
620+ if shared_args is not None :
621+ assert isinstance (shared_args , dict )
622+ share .save (** shared_args )
623+ share .save (t = self .t0 + i * self .dt , i = i , dt = self .dt )
624+ self ._step_func_input ()
637625
638626 # dynamics update step
639- args = () if x is None else (x ,)
640- out = self .target (* args )
627+ out = self .target (* x )
641628
642629 # monitor step
643- shared ['t' ] += self .dt
644- mon = self ._step_func_monitor (shared )
630+ mon = self ._step_func_monitor ()
645631
646632 # finally
647633 if self .progress_bar :
648634 id_tap (lambda * arg : self ._pbar .update (), ())
649- share .clear_shargs ()
635+ # share.clear_shargs()
650636 self .target .clear_input ()
651637
652638 if self ._memory_efficient :
@@ -655,40 +641,23 @@ def _step_func_predict(self, shared_args, t, i, x):
655641 else :
656642 return out , mon
657643
658- def _get_f_predict (self , shared_args : Dict = None ):
659- if shared_args is None :
660- shared_args = dict ()
661-
662- shared_kwargs_str = serialize_kwargs (shared_args )
663- if shared_kwargs_str not in self ._f_predict_compiled :
664-
665- if self ._memory_efficient :
666- _jit_step = bm .jit (partial (self ._step_func_predict , shared_args ))
667-
668- def run_func (all_inputs ):
669- outs = None
670- times , indices , xs = all_inputs
671- for i in range (times .shape [0 ]):
672- out , _ = _jit_step (times [i ], indices [i ], tree_map (lambda a : a [i ], xs ))
673- if outs is None :
674- outs = tree_map (lambda a : [], out )
675- outs = tree_map (lambda a , o : o .append (a ), out , outs )
676- outs = tree_map (lambda a : bm .as_jax (a ), outs )
677- return outs , None
678-
644+ def _fun_predict (self , indices , * inputs , shared_args = None ):
645+ if self ._memory_efficient :
646+ if self .jit ['predict' ]:
647+ run_fun = self ._jit_step_func_predict
679648 else :
680- step = partial ( self ._step_func_predict , shared_args )
649+ run_fun = self ._step_func_predict
681650
682- def run_func (all_inputs ):
683- return bm .for_loop (step , all_inputs , jit = self .jit ['predict' ])
684-
685- self ._f_predict_compiled [shared_kwargs_str ] = run_func
686-
687- return self ._f_predict_compiled [shared_kwargs_str ]
688-
689- def __del__ (self ):
690- if hasattr (self , '_f_predict_compiled' ):
691- for key in tuple (self ._f_predict_compiled .keys ()):
692- self ._f_predict_compiled .pop (key )
693- super (DSRunner , self ).__del__ ()
651+ outs = None
652+ for i in range (indices .shape [0 ]):
653+ out , _ = run_fun (indices [i ], * tree_map (lambda a : a [i ], inputs ), shared_args = shared_args )
654+ if outs is None :
655+ outs = tree_map (lambda a : [], out )
656+ outs = tree_map (lambda a , o : o .append (a ), out , outs )
657+ outs = tree_map (lambda a : bm .as_jax (a ), outs )
658+ return outs , None
694659
660+ else :
661+ return bm .for_loop (functools .partial (self ._step_func_predict , shared_args = shared_args ),
662+ (indices , * inputs ),
663+ jit = self .jit ['predict' ])
0 commit comments