diff --git a/ksim/task/ppo.py b/ksim/task/ppo.py index 32c77fe0..bf41d41f 100644 --- a/ksim/task/ppo.py +++ b/ksim/task/ppo.py @@ -89,7 +89,8 @@ def compute_gae_and_targets_for_sample( trunc_mask_t = jnp.where(successes_t, 1.0, 0.0) bootstrapped_rewards_t = rewards_t / rollout_length_steps + decay_gamma * values_t * trunc_mask_t - mask_t = jnp.where(dones_t, 0.0, 1.0) + is_failure = dones_t & ~successes_t + mask_t = jnp.where(is_failure, 0.0, 1.0) # Compute returns and GAE. deltas_t = bootstrapped_rewards_t + decay_gamma * values_shifted_t * mask_t - values_t