@@ -385,20 +385,13 @@ def _get_amp_disc_loss_and_metrics(
385385
386386 sim_motions = self .trajectory_to_motion (trajectories )
387387
388- # Adds noise to the real and sim motions.
389- sim_rng , real_rng = jax .random .split (rng )
390- max_noise = self .config .amp_reference_noise
391- min_noise = max_noise * self .config .amp_reference_noise_min_multiplier
392- cur_level = carry .env_states .curriculum_state .level .mean ()
393- noise_level = max_noise - (max_noise - min_noise ) * cur_level
394- sim_motions = sim_motions + jax .random .normal (sim_rng , sim_motions .shape ) * noise_level
395- real_motions = real_motions + jax .random .normal (real_rng , real_motions .shape ) * noise_level
396-
397388 # Computes the discriminator loss.
398389 disc_fn = xax .vmap (self .call_discriminator , in_axes = (None , 0 , 0 ), jit_level = JitLevel .RL_CORE )
399390 real_disc_rng , sim_disc_rng = jax .random .split (rng )
400- real_disc_logits = disc_fn (model , real_motions , jax .random .split (real_disc_rng , real_motions .shape [0 ]))
401- sim_disc_logits = disc_fn (model , sim_motions , jax .random .split (sim_disc_rng , sim_motions .shape [0 ]))
391+ real_batch = jax .tree_util .tree_leaves (real_motions )[0 ].shape [0 ]
392+ sim_batch = jax .tree_util .tree_leaves (sim_motions )[0 ].shape [0 ]
393+ real_disc_logits = disc_fn (model , real_motions , jax .random .split (real_disc_rng , real_batch ))
394+ sim_disc_logits = disc_fn (model , sim_motions , jax .random .split (sim_disc_rng , sim_batch ))
402395 real_disc_loss , sim_disc_loss = self .get_disc_losses (real_disc_logits , sim_disc_logits )
403396
404397 disc_loss = real_disc_loss + sim_disc_loss
@@ -425,19 +418,6 @@ def _get_disc_metrics_and_grads(
425418 grads , metrics = loss_fn (model_arr , model_static , trajectories , real_motions , carry , rng )
426419 return metrics , grads
427420
428- @staticmethod
429- def _make_real_batch (
430- motions : Array ,
431- window_t : int ,
432- batch_b : int ,
433- rng : PRNGKeyArray ,
434- ) -> Array :
435- num_motions = motions .shape [0 ]
436- clip_rng , start_rng = jax .random .split (rng )
437- clip_idx = jax .random .randint (clip_rng , (batch_b ,), 0 , num_motions )
438- start_idx = jax .random .randint (start_rng , (batch_b ,), 0 , window_t )
439- return jax .vmap (_loop_slice , in_axes = (0 , 0 , None ))(motions [clip_idx ], start_idx , window_t )
440-
441421 @xax .jit (static_argnames = ["self" , "constants" ], jit_level = JitLevel .RL_CORE )
442422 def _single_step (
443423 self ,
@@ -534,3 +514,41 @@ def update_model(
534514 rewards = rewards ,
535515 rng = rng ,
536516 )
517+
518+ @staticmethod # ensure consistent calling convention
519+ def _make_real_batch (
520+ motions : PyTree ,
521+ window_t : int ,
522+ batch_b : int ,
523+ rng : PRNGKeyArray ,
524+ ) -> PyTree :
525+ """Sample a batch of windowed motion snippets from a PyTree of motions.
526+
527+ Args:
528+ motions: A PyTree whose leaves are arrays of shape (B, T, ...).
529+ window_t: Length of the temporal window to sample.
530+ batch_b: Number of windows to sample.
531+ rng: PRNG key used for sampling.
532+
533+ Returns:
534+ A PyTree with the same structure as ``motions`` whose leaves have
535+ shape (batch_b, window_t, ...).
536+ """
537+ num_motions = jax .tree_util .tree_leaves (motions )[0 ].shape [0 ]
538+
539+ keys = jax .random .split (rng , batch_b + 1 )
540+ clip_key , sample_keys = keys [0 ], keys [1 :]
541+
542+ # Sample which clip each element in the batch comes from.
543+ clip_idx = jax .random .randint (clip_key , (batch_b ,), 0 , num_motions )
544+
545+ batch_clips = jax .tree_util .tree_map (lambda arr : arr [clip_idx ], motions )
546+
547+ def _sample_single (clip : PyTree , rng_key : PRNGKeyArray ) -> PyTree :
548+ """Samples an unbiased window from a single motion clip."""
549+ # Length of the real clip (may differ across clips).
550+ t_real = jax .tree_util .tree_leaves (clip )[0 ].shape [0 ]
551+ start = jax .random .randint (rng_key , (), 0 , t_real ) # unbiased start index
552+ return jax .tree_util .tree_map (lambda arr : _loop_slice (arr , start , window_t ), clip )
553+
554+ return jax .vmap (_sample_single )(batch_clips , sample_keys )
0 commit comments