3131def _is_brainpy_array (x ):
3232 return isinstance (x , bm .Array )
3333
34+
3435def check_and_format_inputs (host , inputs ):
3536 """Check inputs and get the formatted inputs for the given population.
3637
@@ -292,10 +293,10 @@ class DSRunner(Runner):
292293 numpy_mon_after_run : bool
293294 When finishing the network running, transform the JAX arrays into numpy ndarray or not?
294295
295- time_major: bool
296+ data_first_axis: str
296297 Set the default data dimension arrangement.
297- To indicate whether the first axis is the batch size (``time_major=False ``) or the
298- time length (``time_major=True ``).
298+ To indicate whether the first axis is the batch size (``data_first_axis='B' ``) or the
299+ time length (``data_first_axis='T' ``).
299300 In order to be compatible with previous API, default is set to be ``False``.
300301
301302 .. versionadded:: 2.3.1
@@ -311,22 +312,22 @@ def __init__(
311312 inputs : Union [Sequence , Callable ] = (),
312313
313314 # monitors
314- monitors : Union [Sequence , Dict ] = None ,
315+ monitors : Optional [ Union [Sequence , Dict ] ] = None ,
315316 numpy_mon_after_run : bool = True ,
316317
317318 # jit
318319 jit : Union [bool , Dict [str , bool ]] = True ,
319- dyn_vars : Union [bm .Variable , Sequence [bm .Variable ], Dict [str , bm .Variable ]] = None ,
320+ dyn_vars : Optional [ Union [bm .Variable , Sequence [bm .Variable ], Dict [str , bm .Variable ] ]] = None ,
320321
321322 # extra info
322- dt : float = None ,
323+ dt : Optional [ float ] = None ,
323324 t0 : Union [float , int ] = 0. ,
324325 progress_bar : bool = True ,
325- time_major : bool = False ,
326+ data_first_axis : Optional [ str ] = None ,
326327
327328 # deprecated
328- fun_inputs : Callable = None ,
329- fun_monitors : Dict [str , Callable ] = None ,
329+ fun_inputs : Optional [ Callable ] = None ,
330+ fun_monitors : Optional [ Dict [str , Callable ] ] = None ,
330331 ):
331332 if not isinstance (target , DynamicalSystem ):
332333 raise RunningError (f'"target" must be an instance of { DynamicalSystem .__name__ } , '
@@ -344,7 +345,10 @@ def __init__(
344345 self ._t0 = t0
345346 self .i0 = bm .Variable (bm .asarray ([1 ], dtype = bm .int_ ))
346347 self .t0 = bm .Variable (bm .asarray ([t0 ], dtype = bm .float_ ))
347- self .time_major = time_major
348+ if data_first_axis is None :
349+ data_first_axis = 'B' if isinstance (self .target , bm .BatchingMode ) else 'T'
350+ assert data_first_axis in ['B' , 'T' ]
351+ self .data_first_axis = data_first_axis
348352
349353 # parameters
350354 dt = bm .get_dt () if dt is None else dt
@@ -372,7 +376,7 @@ def __repr__(self):
372376 return (f'{ name } (target={ tools .repr_context (str (self .target ), indent2 )} , \n '
373377 f'{ indent } jit={ self .jit } ,\n '
374378 f'{ indent } dt={ self .dt } ,\n '
375- f'{ indent } time_major ={ self .time_major } )' )
379+ f'{ indent } data_first_axis ={ self .data_first_axis } )' )
376380
377381 def reset_state (self ):
378382 """Reset state of the ``DSRunner``."""
@@ -407,8 +411,8 @@ def predict(
407411
408412 - If the mode of ``target`` is instance of :py:class:`~.BatchingMode`,
409413 ``inputs`` must be a PyTree of data with two dimensions:
410- ``(batch, time, ...)`` when ``time_major=False ``,
411- or ``(time, batch, ...)`` when ``time_major=True ``.
414+ ``(batch, time, ...)`` when ``data_first_axis='B' ``,
415+ or ``(time, batch, ...)`` when ``data_first_axis='T' ``.
412416 - If the mode of ``target`` is instance of :py:class:`~.NonBatchingMode`,
413417 the ``inputs`` should be a PyTree of data with one dimension:
414418 ``(time, ...)``.
@@ -462,7 +466,7 @@ def predict(
462466 shared ['i' ] += self .i0
463467 shared ['t' ] += self .t0
464468
465- if isinstance (self .target .mode , bm .BatchingMode ) and not self .time_major :
469+ if isinstance (self .target .mode , bm .BatchingMode ) and self .data_first_axis == 'B' :
466470 inputs = tree_map (lambda x : bm .moveaxis (x , 0 , 1 ),
467471 inputs ,
468472 is_leaf = lambda x : isinstance (x , bm .Array ))
@@ -530,7 +534,7 @@ def _predict(
530534 """
531535 _predict_func = self ._get_f_predict (shared_args )
532536 outs_and_mons = _predict_func (xs )
533- if isinstance (self .target .mode , bm .BatchingMode ) and not self .time_major :
537+ if isinstance (self .target .mode , bm .BatchingMode ) and self .data_first_axis == 'B' :
534538 outs_and_mons = tree_map (lambda x : bm .moveaxis (x , 0 , 1 ),
535539 outs_and_mons ,
536540 is_leaf = lambda x : isinstance (x , bm .Array ))
@@ -573,9 +577,9 @@ def _get_input_batch_size(self, xs=None) -> Optional[int]:
573577 if isinstance (self .target .mode , bm .NonBatchingMode ):
574578 return None
575579 if isinstance (xs , (bm .Array , jax .Array , np .ndarray )):
576- return xs .shape [1 ] if self .time_major else xs .shape [0 ]
580+ return xs .shape [1 ] if self .data_first_axis == 'T' else xs .shape [0 ]
577581 leaves , _ = tree_flatten (xs , is_leaf = _is_brainpy_array )
578- if self .time_major :
582+ if self .data_first_axis == 'T' :
579583 num_batch_sizes = [x .shape [1 ] for x in leaves ]
580584 else :
581585 num_batch_sizes = [x .shape [0 ] for x in leaves ]
@@ -590,19 +594,13 @@ def _get_input_time_step(self, duration=None, xs=None) -> int:
590594 return int (duration / self .dt )
591595 if xs is not None :
592596 if isinstance (xs , (bm .Array , jnp .ndarray )):
593- if isinstance (self .target .mode , bm .BatchingMode ):
594- return xs .shape [0 ] if self .time_major else xs .shape [1 ]
595- else :
596- return xs .shape [0 ]
597+ return xs .shape [0 ] if self .data_first_axis == 'T' else xs .shape [1 ]
597598 else :
598599 leaves , _ = tree_flatten (xs , is_leaf = lambda x : isinstance (x , bm .Array ))
599- if isinstance (self .target .mode , bm .BatchingMode ):
600- if self .time_major :
601- num_steps = [x .shape [0 ] for x in leaves ]
602- else :
603- num_steps = [x .shape [1 ] for x in leaves ]
604- else :
600+ if self .data_first_axis == 'T' :
605601 num_steps = [x .shape [0 ] for x in leaves ]
602+ else :
603+ num_steps = [x .shape [1 ] for x in leaves ]
606604 if len (set (num_steps )) != 1 :
607605 raise ValueError (f'Number of time step is different across arrays in '
608606 f'the provided "xs". We got { set (num_steps )} .' )
0 commit comments