Skip to content

Commit a459739

Browse files
authored
Update AMP flow to work with pytrees (#449)
* remove default added noise * update batching * fix clip sampling * bump version to 0.1.98
1 parent 7d5f7f8 commit a459739

File tree

2 files changed

+43
-25
lines changed

2 files changed

+43
-25
lines changed

ksim/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Defines the main ksim API."""
22

3-
__version__ = "0.1.97"
3+
__version__ = "0.1.98"
44

55
from .actuators import *
66
from .commands import *

ksim/task/amp.py

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -385,20 +385,13 @@ def _get_amp_disc_loss_and_metrics(
385385

386386
sim_motions = self.trajectory_to_motion(trajectories)
387387

388-
# Adds noise to the real and sim motions.
389-
sim_rng, real_rng = jax.random.split(rng)
390-
max_noise = self.config.amp_reference_noise
391-
min_noise = max_noise * self.config.amp_reference_noise_min_multiplier
392-
cur_level = carry.env_states.curriculum_state.level.mean()
393-
noise_level = max_noise - (max_noise - min_noise) * cur_level
394-
sim_motions = sim_motions + jax.random.normal(sim_rng, sim_motions.shape) * noise_level
395-
real_motions = real_motions + jax.random.normal(real_rng, real_motions.shape) * noise_level
396-
397388
# Computes the discriminator loss.
398389
disc_fn = xax.vmap(self.call_discriminator, in_axes=(None, 0, 0), jit_level=JitLevel.RL_CORE)
399390
real_disc_rng, sim_disc_rng = jax.random.split(rng)
400-
real_disc_logits = disc_fn(model, real_motions, jax.random.split(real_disc_rng, real_motions.shape[0]))
401-
sim_disc_logits = disc_fn(model, sim_motions, jax.random.split(sim_disc_rng, sim_motions.shape[0]))
391+
real_batch = jax.tree_util.tree_leaves(real_motions)[0].shape[0]
392+
sim_batch = jax.tree_util.tree_leaves(sim_motions)[0].shape[0]
393+
real_disc_logits = disc_fn(model, real_motions, jax.random.split(real_disc_rng, real_batch))
394+
sim_disc_logits = disc_fn(model, sim_motions, jax.random.split(sim_disc_rng, sim_batch))
402395
real_disc_loss, sim_disc_loss = self.get_disc_losses(real_disc_logits, sim_disc_logits)
403396

404397
disc_loss = real_disc_loss + sim_disc_loss
@@ -425,19 +418,6 @@ def _get_disc_metrics_and_grads(
425418
grads, metrics = loss_fn(model_arr, model_static, trajectories, real_motions, carry, rng)
426419
return metrics, grads
427420

428-
@staticmethod
429-
def _make_real_batch(
430-
motions: Array,
431-
window_t: int,
432-
batch_b: int,
433-
rng: PRNGKeyArray,
434-
) -> Array:
435-
num_motions = motions.shape[0]
436-
clip_rng, start_rng = jax.random.split(rng)
437-
clip_idx = jax.random.randint(clip_rng, (batch_b,), 0, num_motions)
438-
start_idx = jax.random.randint(start_rng, (batch_b,), 0, window_t)
439-
return jax.vmap(_loop_slice, in_axes=(0, 0, None))(motions[clip_idx], start_idx, window_t)
440-
441421
@xax.jit(static_argnames=["self", "constants"], jit_level=JitLevel.RL_CORE)
442422
def _single_step(
443423
self,
@@ -534,3 +514,41 @@ def update_model(
534514
rewards=rewards,
535515
rng=rng,
536516
)
517+
518+
@staticmethod # ensure consistent calling convention
519+
def _make_real_batch(
520+
motions: PyTree,
521+
window_t: int,
522+
batch_b: int,
523+
rng: PRNGKeyArray,
524+
) -> PyTree:
525+
"""Sample a batch of windowed motion snippets from a PyTree of motions.
526+
527+
Args:
528+
motions: A PyTree whose leaves are arrays of shape (B, T, ...).
529+
window_t: Length of the temporal window to sample.
530+
batch_b: Number of windows to sample.
531+
rng: PRNG key used for sampling.
532+
533+
Returns:
534+
A PyTree with the same structure as ``motions`` whose leaves have
535+
shape (batch_b, window_t, ...).
536+
"""
537+
num_motions = jax.tree_util.tree_leaves(motions)[0].shape[0]
538+
539+
keys = jax.random.split(rng, batch_b + 1)
540+
clip_key, sample_keys = keys[0], keys[1:]
541+
542+
# Sample which clip each element in the batch comes from.
543+
clip_idx = jax.random.randint(clip_key, (batch_b,), 0, num_motions)
544+
545+
batch_clips = jax.tree_util.tree_map(lambda arr: arr[clip_idx], motions)
546+
547+
def _sample_single(clip: PyTree, rng_key: PRNGKeyArray) -> PyTree:
548+
"""Samples an unbiased window from a single motion clip."""
549+
# Length of the real clip (may differ across clips).
550+
t_real = jax.tree_util.tree_leaves(clip)[0].shape[0]
551+
start = jax.random.randint(rng_key, (), 0, t_real) # unbiased start index
552+
return jax.tree_util.tree_map(lambda arr: _loop_slice(arr, start, window_t), clip)
553+
554+
return jax.vmap(_sample_single)(batch_clips, sample_keys)

0 commit comments

Comments
 (0)