Skip to content

Commit 0658244

Browse files
committed
change data format specification from time_major to data_first_axis
1 parent 77306b6 commit 0658244

File tree

4 files changed

+41
-41
lines changed

4 files changed

+41
-41
lines changed

brainpy/dyn/runners.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
def _is_brainpy_array(x):
3232
return isinstance(x, bm.Array)
3333

34+
3435
def check_and_format_inputs(host, inputs):
3536
"""Check inputs and get the formatted inputs for the given population.
3637
@@ -292,10 +293,10 @@ class DSRunner(Runner):
292293
numpy_mon_after_run : bool
293294
When finishing the network running, transform the JAX arrays into numpy ndarray or not?
294295
295-
time_major: bool
296+
data_first_axis: str
296297
Set the default data dimension arrangement.
297-
To indicate whether the first axis is the batch size (``time_major=False``) or the
298-
time length (``time_major=True``).
298+
To indicate whether the first axis is the batch size (``data_first_axis='B'``) or the
299+
time length (``data_first_axis='T'``).
299300
In order to be compatible with previous API, default is set to be ``False``.
300301
301302
.. versionadded:: 2.3.1
@@ -311,22 +312,22 @@ def __init__(
311312
inputs: Union[Sequence, Callable] = (),
312313

313314
# monitors
314-
monitors: Union[Sequence, Dict] = None,
315+
monitors: Optional[Union[Sequence, Dict]] = None,
315316
numpy_mon_after_run: bool = True,
316317

317318
# jit
318319
jit: Union[bool, Dict[str, bool]] = True,
319-
dyn_vars: Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]] = None,
320+
dyn_vars: Optional[Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]]] = None,
320321

321322
# extra info
322-
dt: float = None,
323+
dt: Optional[float] = None,
323324
t0: Union[float, int] = 0.,
324325
progress_bar: bool = True,
325-
time_major: bool = False,
326+
data_first_axis: Optional[str] = None,
326327

327328
# deprecated
328-
fun_inputs: Callable = None,
329-
fun_monitors: Dict[str, Callable] = None,
329+
fun_inputs: Optional[Callable] = None,
330+
fun_monitors: Optional[Dict[str, Callable]] = None,
330331
):
331332
if not isinstance(target, DynamicalSystem):
332333
raise RunningError(f'"target" must be an instance of {DynamicalSystem.__name__}, '
@@ -344,7 +345,10 @@ def __init__(
344345
self._t0 = t0
345346
self.i0 = bm.Variable(bm.asarray([1], dtype=bm.int_))
346347
self.t0 = bm.Variable(bm.asarray([t0], dtype=bm.float_))
347-
self.time_major = time_major
348+
if data_first_axis is None:
349+
data_first_axis = 'B' if isinstance(self.target, bm.BatchingMode) else 'T'
350+
assert data_first_axis in ['B', 'T']
351+
self.data_first_axis = data_first_axis
348352

349353
# parameters
350354
dt = bm.get_dt() if dt is None else dt
@@ -372,7 +376,7 @@ def __repr__(self):
372376
return (f'{name}(target={tools.repr_context(str(self.target), indent2)}, \n'
373377
f'{indent}jit={self.jit},\n'
374378
f'{indent}dt={self.dt},\n'
375-
f'{indent}time_major={self.time_major})')
379+
f'{indent}data_first_axis={self.data_first_axis})')
376380

377381
def reset_state(self):
378382
"""Reset state of the ``DSRunner``."""
@@ -407,8 +411,8 @@ def predict(
407411
408412
- If the mode of ``target`` is instance of :py:class:`~.BatchingMode`,
409413
``inputs`` must be a PyTree of data with two dimensions:
410-
``(batch, time, ...)`` when ``time_major=False``,
411-
or ``(time, batch, ...)`` when ``time_major=True``.
414+
``(batch, time, ...)`` when ``data_first_axis='B'``,
415+
or ``(time, batch, ...)`` when ``data_first_axis='T'``.
412416
- If the mode of ``target`` is instance of :py:class:`~.NonBatchingMode`,
413417
the ``inputs`` should be a PyTree of data with one dimension:
414418
``(time, ...)``.
@@ -462,7 +466,7 @@ def predict(
462466
shared['i'] += self.i0
463467
shared['t'] += self.t0
464468

465-
if isinstance(self.target.mode, bm.BatchingMode) and not self.time_major:
469+
if isinstance(self.target.mode, bm.BatchingMode) and self.data_first_axis == 'B':
466470
inputs = tree_map(lambda x: bm.moveaxis(x, 0, 1),
467471
inputs,
468472
is_leaf=lambda x: isinstance(x, bm.Array))
@@ -530,7 +534,7 @@ def _predict(
530534
"""
531535
_predict_func = self._get_f_predict(shared_args)
532536
outs_and_mons = _predict_func(xs)
533-
if isinstance(self.target.mode, bm.BatchingMode) and not self.time_major:
537+
if isinstance(self.target.mode, bm.BatchingMode) and self.data_first_axis == 'B':
534538
outs_and_mons = tree_map(lambda x: bm.moveaxis(x, 0, 1),
535539
outs_and_mons,
536540
is_leaf=lambda x: isinstance(x, bm.Array))
@@ -573,9 +577,9 @@ def _get_input_batch_size(self, xs=None) -> Optional[int]:
573577
if isinstance(self.target.mode, bm.NonBatchingMode):
574578
return None
575579
if isinstance(xs, (bm.Array, jax.Array, np.ndarray)):
576-
return xs.shape[1] if self.time_major else xs.shape[0]
580+
return xs.shape[1] if self.data_first_axis == 'T' else xs.shape[0]
577581
leaves, _ = tree_flatten(xs, is_leaf=_is_brainpy_array)
578-
if self.time_major:
582+
if self.data_first_axis == 'T':
579583
num_batch_sizes = [x.shape[1] for x in leaves]
580584
else:
581585
num_batch_sizes = [x.shape[0] for x in leaves]
@@ -590,19 +594,13 @@ def _get_input_time_step(self, duration=None, xs=None) -> int:
590594
return int(duration / self.dt)
591595
if xs is not None:
592596
if isinstance(xs, (bm.Array, jnp.ndarray)):
593-
if isinstance(self.target.mode, bm.BatchingMode):
594-
return xs.shape[0] if self.time_major else xs.shape[1]
595-
else:
596-
return xs.shape[0]
597+
return xs.shape[0] if self.data_first_axis == 'T' else xs.shape[1]
597598
else:
598599
leaves, _ = tree_flatten(xs, is_leaf=lambda x: isinstance(x, bm.Array))
599-
if isinstance(self.target.mode, bm.BatchingMode):
600-
if self.time_major:
601-
num_steps = [x.shape[0] for x in leaves]
602-
else:
603-
num_steps = [x.shape[1] for x in leaves]
604-
else:
600+
if self.data_first_axis == 'T':
605601
num_steps = [x.shape[0] for x in leaves]
602+
else:
603+
num_steps = [x.shape[1] for x in leaves]
606604
if len(set(num_steps)) != 1:
607605
raise ValueError(f'Number of time step is different across arrays in '
608606
f'the provided "xs". We got {set(num_steps)}.')

brainpy/dyn/transform.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,13 @@ class LoopOverTime(DynSysToBPObj):
7171
>>> over_time.reset_state(n_batch)
7272
(30, 128, 2)
7373
>>>
74-
>>> hist_l3 = over_time(bm.random.rand(n_time, n_batch, n_in), time_major=True)
74+
>>> hist_l3 = over_time(bm.random.rand(n_time, n_batch, n_in), data_first_axis='T')
7575
>>> print(hist_l3.shape)
7676
>>>
7777
>>> # monitor the "l1" layer state
7878
>>> over_time = bp.LoopOverTime(model, out_vars=model['l1'].state)
7979
>>> over_time.reset_state(n_batch)
80-
>>> hist_l3, hist_l1 = over_time(bm.random.rand(n_time, n_batch, n_in), time_major=True)
80+
>>> hist_l3, hist_l1 = over_time(bm.random.rand(n_time, n_batch, n_in), data_first_axis='T')
8181
>>> print(hist_l3.shape)
8282
(30, 128, 2)
8383
>>> print(hist_l1.shape)
@@ -148,7 +148,7 @@ def __call__(
148148
t0: float = 0.,
149149
dt: Optional[float] = None,
150150
shared_arg: Optional[Dict] = None,
151-
time_major: bool = True
151+
data_first_axis: str = 'T'
152152
):
153153
"""Forward propagation along the time or inputs.
154154
@@ -164,7 +164,7 @@ def __call__(
164164
shared_arg: dict
165165
The shared arguments across the nodes.
166166
For instance, `shared_arg={'fit': False}` for the prediction phase.
167-
time_major: bool
167+
data_first_axis: str
168168
Denote whether the input data is time major.
169169
If so, we treat the data as `(time, batch, ...)` when the `target` is in Batching mode.
170170
Default is True.
@@ -174,6 +174,8 @@ def __call__(
174174
out: PyTree
175175
The accumulated outputs over time.
176176
"""
177+
assert data_first_axis in ['B', 'T']
178+
177179
is_float(t0, 't0')
178180
is_float(dt, 'dt', allow_none=True)
179181
dt = bm.get_dt() if dt is None else dt
@@ -194,11 +196,11 @@ def __call__(
194196
else:
195197
inp_err_msg = ('\n'
196198
'Input should be a Array PyTree with the shape '
197-
'of (B, T, ...) or (T, B, ...) with `time_major=True`, '
199+
'of (B, T, ...) or (T, B, ...) with `data_first_axis="T"`, '
198200
'where B the batch size and T the time length.')
199201
xs, tree = tree_flatten(duration_or_xs, lambda a: isinstance(a, bm.Array))
200202
if isinstance(self.target.mode, bm.BatchingMode):
201-
b_idx, t_idx = (1, 0) if time_major else (0, 1)
203+
b_idx, t_idx = (1, 0) if data_first_axis == 'T' else (0, 1)
202204

203205
try:
204206
batch = tuple(set([x.shape[b_idx] for x in xs]))
@@ -224,10 +226,10 @@ def __call__(
224226
if self.no_state:
225227
xs = [jnp.reshape(x, (length[0] * batch[0],) + x.shape[2:]) for x in xs]
226228
else:
227-
if not time_major:
229+
if data_first_axis == 'B':
228230
xs = [jnp.moveaxis(x, 0, 1) for x in xs]
229231
xs = tree_unflatten(tree, xs)
230-
origin_shape = (length[0], batch[0]) if time_major else (batch[0], length[0])
232+
origin_shape = (length[0], batch[0]) if data_first_axis == 'T' else (batch[0], length[0])
231233

232234
else:
233235

brainpy/train/back_propagation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -491,17 +491,17 @@ def loss_fun(predicts, targets):
491491
Make the monitored results as NumPy arrays.
492492
logger: Any
493493
A file-like object (stream). Used to output the running results. Default is the current `sys.stdout`.
494-
time_major: bool
495-
To indicate whether the first axis is the batch size (``time_major=False``) or the
496-
time length (``time_major=True``).
494+
data_first_axis: str
495+
To indicate whether the first axis is the batch size (``data_first_axis='B'``) or the
496+
time length (``data_first_axis='T'``).
497497
"""
498498

499499
def _step_func_loss(self, shared_args, inputs, targets):
500500
num_step = self._get_input_time_step(xs=inputs)
501501
indices = jnp.arange(num_step, dtype=bm.int_)
502502
times = indices * self.dt + self.t0
503503
indices = indices + self.i0
504-
if isinstance(self.target.mode, bm.BatchingMode) and not self.time_major:
504+
if isinstance(self.target.mode, bm.BatchingMode) and self.data_first_axis == 'B':
505505
inputs = tree_map(lambda x: bm.moveaxis(x, 0, 1), inputs, is_leaf=lambda x: isinstance(x, bm.Array))
506506
inputs = (times, indices, inputs)
507507
outs, mons = self._predict(xs=inputs, shared_args=shared_args)
@@ -535,7 +535,7 @@ def _step_func_train(self, shared_args, inputs, targets):
535535
return res[1:]
536536

537537
def _step_func_predict(self, shared, x=None):
538-
assert not self.time_major, f'There is no time dimension when using the trainer {self.__class__.__name__}.'
538+
assert self.data_first_axis == 'B', f'There is no time dimension when using the trainer {self.__class__.__name__}.'
539539

540540
# input step
541541
self.target.clear_input()

brainpy/train/online.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def fit(
171171
shared['t'] += self.t0
172172
shared['i'] += self.i0
173173

174-
if not self.time_major:
174+
if self.data_first_axis == 'B':
175175
xs = tree_map(lambda x: bm.moveaxis(x, 0, 1),
176176
xs,
177177
is_leaf=lambda x: isinstance(x, bm.Array))

0 commit comments

Comments
 (0)