1717from brainpy ._src .dyn .base import DynamicalSystem
1818from brainpy ._src .dyn .context import share
1919from brainpy ._src .running .runner import Runner
20- from brainpy .check import is_float , serialize_kwargs
21- from brainpy .errors import RunningError , NoLongerSupportError
20+ from brainpy .check import serialize_kwargs
21+ from brainpy .errors import RunningError
2222from brainpy .types import ArrayType , Output , Monitor
2323
2424__all__ = [
@@ -319,6 +319,7 @@ def __init__(
319319 # jit
320320 jit : Union [bool , Dict [str , bool ]] = True ,
321321 dyn_vars : Optional [Union [bm .Variable , Sequence [bm .Variable ], Dict [str , bm .Variable ]]] = None ,
322+ memory_efficient : bool = False ,
322323
323324 # extra info
324325 dt : Optional [float ] = None ,
@@ -342,10 +343,9 @@ def __init__(
342343 numpy_mon_after_run = numpy_mon_after_run )
343344
344345 # t0 and i0
345- is_float ( t0 , 't0' , allow_none = False , allow_int = True )
346+ self . i0 = 0
346347 self ._t0 = t0
347- self .i0 = bm .Variable (jnp .asarray (1 , dtype = bm .int_ ))
348- self .t0 = bm .Variable (jnp .asarray (t0 , dtype = bm .float_ ))
348+ self .t0 = t0
349349 if data_first_axis is None :
350350 data_first_axis = 'B' if isinstance (self .target .mode , bm .BatchingMode ) else 'T'
351351 assert data_first_axis in ['B' , 'T' ]
@@ -371,6 +371,11 @@ def __init__(
371371 # run function
372372 self ._f_predict_compiled = dict ()
373373
374+ # monitors
375+ self ._memory_efficient = memory_efficient
376+ if memory_efficient and not numpy_mon_after_run :
377+ raise ValueError ('When setting "gpu_memory_efficient=True", "numpy_mon_after_run" can not be False.' )
378+
374379 def __repr__ (self ):
375380 name = self .__class__ .__name__
376381 indent = " " * len (name ) + ' '
@@ -382,8 +387,8 @@ def __repr__(self):
382387
383388 def reset_state (self ):
384389 """Reset state of the ``DSRunner``."""
385- self .i0 . value = jnp . zeros_like ( self . i0 . value )
386- self .t0 . value = jnp . ones_like ( self . t0 . value ) * self ._t0
390+ self .i0 = 0
391+ self .t0 = self ._t0
387392
388393 def predict (
389394 self ,
@@ -438,11 +443,12 @@ def predict(
438443 """
439444
440445 if inputs_are_batching is not None :
441- raise NoLongerSupportError (
446+ raise warnings . warn (
442447 f'''
443448 `inputs_are_batching` is no longer supported.
444449 The target mode of { self .target .mode } has already indicated the input should be batching.
445- '''
450+ ''' ,
451+ UserWarning
446452 )
447453 if duration is None :
448454 if inputs is None :
@@ -466,7 +472,7 @@ def predict(
466472 if shared_args is None :
467473 shared_args = dict ()
468474 shared_args ['fit' ] = shared_args .get ('fit' , False )
469- shared = tools .DotDict (i = jnp .arange (num_step , dtype = bm .int_ ))
475+ shared = tools .DotDict (i = np .arange (num_step , dtype = bm .int_ ))
470476 shared ['t' ] = shared ['i' ] * self .dt
471477 shared ['i' ] += self .i0
472478 shared ['t' ] += self .t0
@@ -486,7 +492,8 @@ def predict(
486492 # running
487493 if eval_time :
488494 t0 = time .time ()
489- outputs , hists = self ._predict (xs = (shared ['t' ], shared ['i' ], inputs ), shared_args = shared_args )
495+ with jax .disable_jit (not self .jit ['predict' ]):
496+ outputs , hists = self ._predict (xs = (shared ['t' ], shared ['i' ], inputs ), shared_args = shared_args )
490497 if eval_time :
491498 running_time = time .time () - t0
492499
@@ -495,11 +502,16 @@ def predict(
495502 self ._pbar .close ()
496503
497504 # post-running for monitors
498- hists ['ts' ] = shared ['t' ] + self .dt
499- if self .numpy_mon_after_run :
500- hists = tree_map (lambda a : np .asarray (a ), hists , is_leaf = lambda a : isinstance (a , bm .Array ))
501- for key in hists .keys ():
502- self .mon [key ] = hists [key ]
505+ if self ._memory_efficient :
506+ self .mon ['ts' ] = shared ['t' ] + self .dt
507+ for key in self .mon .var_names :
508+ self .mon [key ] = np .asarray (self .mon [key ])
509+ else :
510+ hists ['ts' ] = shared ['t' ] + self .dt
511+ if self .numpy_mon_after_run :
512+ hists = tree_map (lambda a : np .asarray (a ), hists , is_leaf = lambda a : isinstance (a , bm .Array ))
513+ for key in hists .keys ():
514+ self .mon [key ] = hists [key ]
503515 self .i0 += num_step
504516 self .t0 += (num_step * self .dt if duration is None else duration )
505517 return outputs if not eval_time else (running_time , outputs )
@@ -609,10 +621,13 @@ def _get_input_time_step(self, duration=None, xs=None) -> int:
609621 raise ValueError (f'Number of time step is different across arrays in '
610622 f'the provided "xs". We got { set (num_steps )} .' )
611623 return num_steps [0 ]
612-
613624 else :
614625 raise ValueError
615626
627+ def _step_mon_on_cpu (self , args , transforms ):
628+ for key , val in args .items ():
629+ self .mon [key ].append (val )
630+
616631 def _step_func_predict (self , shared_args , t , i , x ):
617632 # input step
618633 shared = tools .DotDict (t = t , i = i , dt = self .dt )
@@ -633,7 +648,12 @@ def _step_func_predict(self, shared_args, t, i, x):
633648 if self .progress_bar :
634649 id_tap (lambda * arg : self ._pbar .update (), ())
635650 share .clear_shargs ()
636- return out , mon
651+
652+ if self ._memory_efficient :
653+ id_tap (self ._step_mon_on_cpu , mon )
654+ return out , None
655+ else :
656+ return out , mon
637657
638658 def _get_f_predict (self , shared_args : Dict = None ):
639659 if shared_args is None :
@@ -646,16 +666,30 @@ def _get_f_predict(self, shared_args: Dict = None):
646666 dyn_vars .update (self .vars (level = 0 ))
647667 dyn_vars = dyn_vars .unique ()
648668
649- def run_func (all_inputs ):
650- return bm .for_loop (partial (self ._step_func_predict , shared_args ),
651- all_inputs ,
652- dyn_vars = dyn_vars ,
653- jit = self .jit ['predict' ])
669+ if self ._memory_efficient :
670+ _jit_step = bm .jit (partial (self ._step_func_predict , shared_args ), dyn_vars = dyn_vars )
671+
672+ def run_func (all_inputs ):
673+ outs = None
674+ times , indices , xs = all_inputs
675+ for i in range (times .shape [0 ]):
676+ out , _ = _jit_step (times [i ], indices [i ], tree_map (lambda a : a [i ], xs ))
677+ if outs is None :
678+ outs = tree_map (lambda a : [], out )
679+ outs = tree_map (lambda a , o : o .append (a ), out , outs )
680+ outs = tree_map (lambda a : bm .as_jax (a ), outs )
681+ return outs , None
654682
655- if self .jit ['predict' ]:
656- self ._f_predict_compiled [shared_kwargs_str ] = bm .jit (run_func , dyn_vars = dyn_vars )
657683 else :
658- self ._f_predict_compiled [shared_kwargs_str ] = run_func
684+ @bm .jit (dyn_vars = dyn_vars )
685+ def run_func (all_inputs ):
686+ return bm .for_loop (partial (self ._step_func_predict , shared_args ),
687+ all_inputs ,
688+ dyn_vars = dyn_vars ,
689+ jit = self .jit ['predict' ])
690+
691+ self ._f_predict_compiled [shared_kwargs_str ] = run_func
692+
659693 return self ._f_predict_compiled [shared_kwargs_str ]
660694
661695 def __del__ (self ):
0 commit comments