Skip to content

Commit 06c92a8

Browse files
authored
reward updates (#497)
* reward updates * multiple reward outputs * training fixes * nit changes * joystick * oh man, much better * tweaked rewards * reward changes * norm change
1 parent 9a8212a commit 06c92a8

File tree

4 files changed

+84
-73
lines changed

4 files changed

+84
-73
lines changed

examples/walking.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -359,8 +359,8 @@ def get_mujoco_model(self) -> mujoco.MjModel: # pyright: ignore[reportAttribute
359359
def get_mujoco_model_metadata(self, mj_model: mujoco.MjModel) -> ksim.Metadata: # pyright: ignore[reportAttributeAccessIssue]
360360
return ksim.Metadata.from_model(
361361
mj_model,
362-
kp=10.0,
363-
kd=0.1,
362+
kp=50.0,
363+
kd=1.0,
364364
)
365365

366366
def get_actuators(
@@ -461,7 +461,7 @@ def get_commands(self, physics_model: ksim.PhysicsModel) -> list[ksim.Command]:
461461
gait_period=self.config.gait_period,
462462
ctrl_dt=self.config.ctrl_dt,
463463
max_height=self.config.max_foot_height,
464-
height_offset=0.04,
464+
height_offset=0.08,
465465
),
466466
joystick=ksim.JoystickCommand(
467467
run_speed=self.config.target_linear_velocity,
@@ -470,15 +470,12 @@ def get_commands(self, physics_model: ksim.PhysicsModel) -> list[ksim.Command]:
470470
rotation_speed=self.config.target_angular_velocity,
471471
),
472472
),
473-
ksim.BaseHeightCommand(
474-
min_height=0.9,
475-
max_height=1.4,
476-
),
477473
]
478474

479475
def get_rewards(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reward]:
480476
return [
481477
ksim.StayAliveReward(scale=100.0),
478+
ksim.UprightReward(scale=5.0),
482479
ksim.EasyJoystickReward(
483480
gait=ksim.SinusoidalGaitReward(
484481
scale=5.0,
@@ -492,7 +489,6 @@ def get_rewards(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reward]:
492489
scale=1.0,
493490
),
494491
),
495-
ksim.BaseHeightTrackingReward(scale=5.0),
496492
]
497493

498494
def get_terminations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Termination]:
@@ -513,8 +509,8 @@ def get_model(self, params: ksim.InitParams) -> Model:
513509
return Model(
514510
params.key,
515511
physics_model=params.physics_model,
516-
num_actor_inputs=50,
517-
num_critic_inputs=334,
512+
num_actor_inputs=49,
513+
num_critic_inputs=336,
518514
num_joints=17,
519515
min_std=0.01,
520516
max_std=1.0,
@@ -557,16 +553,13 @@ def run_actor(
557553
# Phase is required in order to follow the gait command.
558554
gait_phase_1 = sgj_cmd.gait.phase[..., None]
559555

560-
base_height_1 = commands["base_height_command"][..., None]
561-
562556
obs_n = jnp.concatenate(
563557
[
564558
dh_joint_pos_j, # NUM_JOINTS
565559
dh_joint_vel_j / 10.0, # NUM_JOINTS
566560
proj_grav_3, # 3
567561
imu_gyro_3, # 3
568562
gait_phase_1, # 1
569-
base_height_1, # 1
570563
joystick_cmd_ohe_8, # 8
571564
],
572565
axis=-1,
@@ -597,14 +590,13 @@ def run_critic(
597590
# Sinusoidal gait joystick command.
598591
sgj_cmd: ksim.EasyJoystickCommandValue = commands["easy_joystick_command"]
599592
joystick_cmd_ohe_8 = sgj_cmd.joystick.command
593+
joystick_vel_tgts_3 = sgj_cmd.joystick.vels
600594

601595
# Foot height difference.
602596
foot_height_2 = observations["feet_position_observation"][..., 2]
603597
foot_tgt_height_2 = sgj_cmd.gait.height
604598
foot_height_diff_2 = foot_height_2 - foot_tgt_height_2
605599

606-
base_height_1 = commands["base_height_command"][..., None]
607-
608600
obs_n = jnp.concatenate(
609601
[
610602
dh_joint_pos_j, # NUM_JOINTS
@@ -618,8 +610,8 @@ def run_critic(
618610
lin_vel_obs_3, # 3
619611
ang_vel_obs_3, # 3
620612
foot_height_diff_2, # 2
621-
base_height_1, # 1
622613
joystick_cmd_ohe_8, # 8
614+
joystick_vel_tgts_3, # 3
623615
],
624616
axis=-1,
625617
)

ksim/commands.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -274,22 +274,13 @@ def _update_for(self, cmd: JoystickCommandValue, trajectory: Trajectory) -> None
274274
self.rgba = (r, g, b, 1.0)
275275

276276
cmd_x, cmd_y = cmd_vel[..., 0], cmd_vel[..., 1]
277-
278-
# Gets the robot's current yaw.
279-
quat = trajectory.qpos[..., 3:7]
280-
cur_yaw = xax.quat_to_yaw(quat)
281-
282-
# Rotates the command X and Y velocities to the robot's current yaw.
283-
cmd_x_rot = cmd_x * jnp.cos(cur_yaw) - cmd_y * jnp.sin(cur_yaw)
284-
cmd_y_rot = cmd_x * jnp.sin(cur_yaw) + cmd_y * jnp.cos(cur_yaw)
285-
286277
self.pos = (0, 0, self.height)
287278

288279
match cmd_idx:
289280
case 0:
290281
self._update_circle()
291282
case 1 | 2 | 3 | 6 | 7:
292-
self._update_arrow(cmd_x_rot.item(), cmd_y_rot.item())
283+
self._update_arrow(cmd_x.item(), cmd_y.item())
293284
case 4 | 5:
294285
self._update_cylinder()
295286
case _:
@@ -359,7 +350,7 @@ class JoystickCommand(Command):
359350
marker_z_offset: float = attrs.field(default=0.5)
360351
switch_prob: float = attrs.field(default=0.005)
361352

362-
def _get_vel_tgts(self, physics_data: PhysicsData, command: Array) -> Array:
353+
def _get_vel_tgts(self, command: Array) -> Array:
363354
# Gets the target X, Y, and Yaw targets.
364355
cmd_tgts = jnp.array(
365356
[
@@ -385,9 +376,9 @@ def initial_command(
385376
curriculum_level: Array,
386377
rng: PRNGKeyArray,
387378
) -> JoystickCommandValue:
388-
command = jax.random.choice(rng, jnp.arange(len(self.sample_probs)), p=jnp.array(self.sample_probs))
389-
command_ohe = jax.nn.one_hot(command, num_classes=8)
390-
vel_tgts = self._get_vel_tgts(physics_data, command)
379+
command = jax.random.choice(rng, len(self.sample_probs), p=jnp.array(self.sample_probs))
380+
command_ohe = jax.nn.one_hot(command, num_classes=len(self.sample_probs))
381+
vel_tgts = self._get_vel_tgts(command)
391382
return JoystickCommandValue(
392383
command=command_ohe,
393384
vels=vel_tgts,

ksim/rewards.py

Lines changed: 61 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
import functools
4343
import logging
4444
from abc import ABC, abstractmethod
45-
from typing import Collection, Literal, Self, final
45+
from typing import Collection, Literal, Mapping, Self, final
4646

4747
import attrs
4848
import chex
@@ -52,7 +52,6 @@
5252
from jaxtyping import Array, PRNGKeyArray, PyTree
5353

5454
from ksim.commands import EasyJoystickCommandValue, JoystickCommandValue, SinusoidalGaitCommandValue
55-
from ksim.debugging import JitLevel
5655
from ksim.types import PhysicsModel, Trajectory
5756
from ksim.utils.mujoco import get_body_data_idx_from_name, get_qpos_data_idxs_by_name
5857
from ksim.utils.validators import (
@@ -117,7 +116,7 @@ class Reward(ABC):
117116
scale_by_curriculum: bool = attrs.field(default=False)
118117

119118
@abstractmethod
120-
def get_reward(self, trajectory: Trajectory) -> Array:
119+
def get_reward(self, trajectory: Trajectory) -> Array | Mapping[str, Array]:
121120
"""Get the reward for a single trajectory.
122121
123122
Args:
@@ -158,7 +157,11 @@ def initial_carry(self, rng: PRNGKeyArray) -> PyTree:
158157
"""
159158

160159
@abstractmethod
161-
def get_reward_stateful(self, trajectory: Trajectory, reward_carry: PyTree) -> tuple[Array, PyTree]:
160+
def get_reward_stateful(
161+
self,
162+
trajectory: Trajectory,
163+
reward_carry: PyTree,
164+
) -> tuple[Array | Mapping[str, Array], PyTree]:
162165
"""Get the reward for a single trajectory.
163166
164167
This is the same as `get_reward`, but it also takes in the reward carry
@@ -718,9 +721,11 @@ def _update_arrow(self, cmd_x: float, cmd_y: float) -> None:
718721

719722
def update(self, trajectory: Trajectory) -> None:
720723
"""Visualizes the joystick command target position and orientation."""
721-
cur_xvel, cur_yvel = trajectory.qvel[..., 0].item(), trajectory.qvel[..., 1].item()
724+
quat = JoystickReward.get_quat(trajectory)
725+
linvel = trajectory.qvel[..., :3]
726+
linvel = xax.rotate_vector_by_quat(linvel, quat, inverse=True)
722727
self.pos = (0, 0, self.height)
723-
self._update_arrow(cur_xvel, cur_yvel)
728+
self._update_arrow(linvel[..., 0].item(), linvel[..., 1].item())
724729

725730
@classmethod
726731
def get(
@@ -752,42 +757,55 @@ class JoystickReward(Reward):
752757
"""Reward for following the joystick command."""
753758

754759
command_name: str = attrs.field(default="joystick_command")
755-
ang_penalty_ratio: float = attrs.field(default=2.0)
760+
dir_scale: float = attrs.field(default=1.0)
761+
mag_scale: float = attrs.field(default=1.0)
762+
yaw_scale: float = attrs.field(default=1.0)
756763

757-
@xax.jit(static_argnames=["self"], jit_level=JitLevel.UNROLL)
758-
def get_reward(self, trajectory: Trajectory) -> Array:
764+
def get_reward(self, trajectory: Trajectory) -> dict[str, Array]:
759765
if self.command_name not in trajectory.command:
760766
raise ValueError(f"Command {self.command_name} not found! Ensure that it is in the task.")
761767
return self._get_reward_for(trajectory.command[self.command_name], trajectory)
762768

763-
def _get_reward_for(self, joystick_cmd: JoystickCommandValue, trajectory: Trajectory) -> Array:
769+
@classmethod
770+
def get_quat(cls, trajectory: Trajectory) -> Array:
771+
quat = trajectory.qpos[..., 3:7]
772+
yaw = xax.quat_to_yaw(quat)
773+
zeros = jnp.zeros_like(yaw)
774+
euler = jnp.stack([zeros, zeros, yaw], axis=-1)
775+
quat = xax.euler_to_quat(euler)
776+
return quat
777+
778+
def _get_reward_for(self, joystick_cmd: JoystickCommandValue, trajectory: Trajectory) -> dict[str, Array]:
764779
# Gets the target X, Y, and Yaw velocities.
765780
tgts = joystick_cmd.vels
766781

767-
# Smooths the target velocities.
768-
trg_xvel, trg_yvel, trg_yawvel = tgts.T
769-
770782
# Gets the robot's current velocities.
771-
cur_xvel = trajectory.qvel[..., 0]
772-
cur_yvel = trajectory.qvel[..., 1]
773-
cur_yawvel = trajectory.qvel[..., 5]
774-
775-
# Gets the robot's current yaw.
776-
quat = trajectory.qpos[..., 3:7]
777-
cur_yaw = xax.quat_to_yaw(quat)
778-
779-
# Rotates the command X and Y velocities to the robot's current yaw.
780-
trg_xvel_rot = trg_xvel * jnp.cos(cur_yaw) - trg_yvel * jnp.sin(cur_yaw)
781-
trg_yvel_rot = trg_xvel * jnp.sin(cur_yaw) + trg_yvel * jnp.cos(cur_yaw)
782-
783-
# Linear reward for tracking the target velocities.
784-
pos_x_rew = -jnp.abs(trg_xvel_rot - cur_xvel)
785-
pos_y_rew = -jnp.abs(trg_yvel_rot - cur_yvel)
786-
rot_z_rew = -jnp.abs(trg_yawvel - cur_yawvel)
787-
788-
reward = (pos_x_rew + pos_y_rew + rot_z_rew) / 3.0
789-
790-
return reward
783+
quat = self.get_quat(trajectory)
784+
linvel = trajectory.qvel[..., :3]
785+
linvel = xax.rotate_vector_by_quat(linvel, quat, inverse=True)
786+
yawvel = trajectory.qvel[..., 5]
787+
788+
# Reward for tracking the direction (cosine similarity).
789+
cur_xy = linvel[..., :2]
790+
trg_xy = tgts[..., :2]
791+
cur_norm = jnp.linalg.norm(cur_xy, axis=-1)
792+
trg_norm = jnp.linalg.norm(trg_xy, axis=-1)
793+
denom_xy = cur_norm * trg_norm
794+
xy_cos_sim = (cur_xy * trg_xy).sum(axis=-1) / denom_xy.clip(min=1e-6)
795+
796+
# Reward for tracking the magnitude, in the direction of the target.
797+
xy_mag_rew = 1.0 - jnp.where(trg_norm < 1e-6, cur_norm, jnp.abs(cur_norm - trg_norm) / trg_norm.clip(min=1e-6))
798+
799+
# Reward for tracking the yaw.
800+
cur_yaw = yawvel
801+
trg_yaw = tgts[..., 2]
802+
yaw_mag_rew = 1.0 - jnp.abs(cur_yaw - trg_yaw)
803+
804+
return {
805+
"dir": xy_cos_sim * self.dir_scale,
806+
"mag": xy_mag_rew * self.mag_scale,
807+
"yaw": yaw_mag_rew * self.yaw_scale,
808+
}
791809

792810
def get_markers(self) -> Collection[Marker]:
793811
return [JoystickRewardMarker.get()]
@@ -1054,19 +1072,25 @@ class EasyJoystickReward(StatefulReward):
10541072
def initial_carry(self, rng: PRNGKeyArray) -> Array:
10551073
return self.airtime.initial_carry(rng)
10561074

1057-
def get_reward_stateful(self, trajectory: Trajectory, reward_carry: Array) -> tuple[Array, Array]:
1075+
def get_reward_stateful(self, trajectory: Trajectory, reward_carry: Array) -> tuple[dict[str, Array], Array]:
10581076
if self.command_name not in trajectory.command:
10591077
raise ValueError(f"Command {self.command_name} not found! Ensure that it is in the task.")
10601078

10611079
cmd: EasyJoystickCommandValue = trajectory.command[self.command_name]
1062-
joystick_reward = self.joystick._get_reward_for(cmd.joystick, trajectory) * self.joystick.scale
1063-
gait_reward = self.gait._get_reward_for(cmd.gait, trajectory) * self.gait.scale
1080+
joystick_reward = self.joystick._get_reward_for(cmd.joystick, trajectory)
1081+
gait_reward = self.gait._get_reward_for(cmd.gait, trajectory)
10641082
airtime_reward, airtime_carry = self.airtime.get_reward_stateful(trajectory, reward_carry)
10651083

10661084
# Mask out airtime reward when the robot is not moving.
10671085
airtime_reward = jnp.where(cmd.joystick.command.argmax(axis=-1) == 0, 0.0, airtime_reward)
10681086

1069-
total_reward = joystick_reward + gait_reward + airtime_reward * self.airtime.scale
1087+
total_reward = {
1088+
"gait": gait_reward * self.gait.scale,
1089+
"airtime": airtime_reward * self.airtime.scale,
1090+
}
1091+
for k, v in joystick_reward.items():
1092+
total_reward[f"joystick/{k}"] = v * self.joystick.scale
1093+
10701094
return total_reward, airtime_carry
10711095

10721096
def get_markers(self) -> Collection[Marker]:

ksim/task/rl.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from pathlib import Path
2727
from threading import Thread
2828
from types import FrameType
29-
from typing import Any, Callable, Collection, Dict, Generic, TypeVar, cast
29+
from typing import Any, Callable, Collection, Dict, Generic, Mapping, TypeVar, cast
3030

3131
import chex
3232
import equinox as eqx
@@ -210,14 +210,18 @@ def get_rewards(
210210
reward_val, reward_carry = reward.get_reward_stateful(trajectory, reward_carry)
211211
else:
212212
reward_val = reward.get_reward(trajectory)
213-
reward_val = reward_val * reward.scale
213+
if isinstance(reward_val, Mapping):
214+
reward_val = {f"{reward_name}/{k}": v * reward.scale for k, v in reward_val.items()}
215+
else:
216+
reward_val = {reward_name: reward_val * reward.scale}
214217
if reward.scale_by_curriculum:
215-
reward_val = reward_val * curriculum_level
218+
reward_val = {k: v * curriculum_level for k, v in reward_val.items()}
216219

217-
if reward_val.shape != trajectory.done.shape:
218-
raise AssertionError(f"Reward {reward_name} shape {reward_val.shape} does not match {target_shape}")
220+
for k, v in reward_val.items():
221+
if v.shape != trajectory.done.shape:
222+
raise AssertionError(f"Reward {k} shape {v.shape} does not match {target_shape}")
219223

220-
reward_dict[reward_name] = reward_val
224+
reward_dict.update(reward_val)
221225
next_reward_carry[reward_name] = reward_carry
222226

223227
total_reward = jax.tree.reduce(jnp.add, list(reward_dict.values()))

0 commit comments

Comments
 (0)