Skip to content

Commit 1e0fe97

Browse files
committed
Add success metric
1 parent 60cf843 commit 1e0fe97

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

mujoco_playground/_src/manipulation/franka_emika_panda/pick_cartesian.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -216,6 +216,8 @@ def reset(self, rng: jax.Array) -> mjx_env.State:
216216
f'reward/{k}': 0.0
217217
for k in self._config.reward_config.reward_scales.keys()
218218
},
219+
'reward/success': jp.array(0.0),
220+
'reward/lifted': jp.array(0.0),
219221
}
220222

221223
info = {
@@ -333,9 +335,8 @@ def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
333335

334336
# Sparse rewards
335337
box_pos = data.xpos[self._obj_body]
336-
total_reward += (
337-
box_pos[2] > 0.05
338-
) * self._config.reward_config.lifted_reward
338+
lifted = (box_pos[2] > 0.05) * self._config.reward_config.lifted_reward
339+
total_reward += lifted
339340
success = self._get_success(data, state.info)
340341
total_reward += success * self._config.reward_config.success_reward
341342

@@ -352,6 +353,10 @@ def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
352353
out_of_bounds |= box_pos[2] < 0.0
353354
state.metrics.update(out_of_bounds=out_of_bounds.astype(float))
354355
state.metrics.update({f'reward/{k}': v for k, v in raw_rewards.items()})
356+
state.metrics.update({
357+
'reward/lifted': lifted.astype(float),
358+
'reward/success': success.astype(float),
359+
})
355360

356361
done = (
357362
out_of_bounds

0 commit comments

Comments (0)