|
42 | 42 | import functools |
43 | 43 | import logging |
44 | 44 | from abc import ABC, abstractmethod |
45 | | -from typing import Collection, Literal, Self, final |
| 45 | +from typing import Collection, Literal, Mapping, Self, final |
46 | 46 |
|
47 | 47 | import attrs |
48 | 48 | import chex |
|
52 | 52 | from jaxtyping import Array, PRNGKeyArray, PyTree |
53 | 53 |
|
54 | 54 | from ksim.commands import EasyJoystickCommandValue, JoystickCommandValue, SinusoidalGaitCommandValue |
55 | | -from ksim.debugging import JitLevel |
56 | 55 | from ksim.types import PhysicsModel, Trajectory |
57 | 56 | from ksim.utils.mujoco import get_body_data_idx_from_name, get_qpos_data_idxs_by_name |
58 | 57 | from ksim.utils.validators import ( |
@@ -117,7 +116,7 @@ class Reward(ABC): |
117 | 116 | scale_by_curriculum: bool = attrs.field(default=False) |
118 | 117 |
|
119 | 118 | @abstractmethod |
120 | | - def get_reward(self, trajectory: Trajectory) -> Array: |
| 119 | + def get_reward(self, trajectory: Trajectory) -> Array | Mapping[str, Array]: |
121 | 120 | """Get the reward for a single trajectory. |
122 | 121 |
|
123 | 122 | Args: |
@@ -158,7 +157,11 @@ def initial_carry(self, rng: PRNGKeyArray) -> PyTree: |
158 | 157 | """ |
159 | 158 |
|
160 | 159 | @abstractmethod |
161 | | - def get_reward_stateful(self, trajectory: Trajectory, reward_carry: PyTree) -> tuple[Array, PyTree]: |
| 160 | + def get_reward_stateful( |
| 161 | + self, |
| 162 | + trajectory: Trajectory, |
| 163 | + reward_carry: PyTree, |
| 164 | + ) -> tuple[Array | Mapping[str, Array], PyTree]: |
162 | 165 | """Get the reward for a single trajectory. |
163 | 166 |
|
164 | 167 | This is the same as `get_reward`, but it also takes in the reward carry |
@@ -718,9 +721,11 @@ def _update_arrow(self, cmd_x: float, cmd_y: float) -> None: |
718 | 721 |
|
719 | 722 | def update(self, trajectory: Trajectory) -> None: |
720 | 723 | """Visualizes the joystick command target position and orientation.""" |
721 | | - cur_xvel, cur_yvel = trajectory.qvel[..., 0].item(), trajectory.qvel[..., 1].item() |
| 724 | + quat = JoystickReward.get_quat(trajectory) |
| 725 | + linvel = trajectory.qvel[..., :3] |
| 726 | + linvel = xax.rotate_vector_by_quat(linvel, quat, inverse=True) |
722 | 727 | self.pos = (0, 0, self.height) |
723 | | - self._update_arrow(cur_xvel, cur_yvel) |
| 728 | + self._update_arrow(linvel[..., 0].item(), linvel[..., 1].item()) |
724 | 729 |
|
725 | 730 | @classmethod |
726 | 731 | def get( |
@@ -752,42 +757,55 @@ class JoystickReward(Reward): |
    """Reward for following the joystick command."""

    # Name of the command (in `trajectory.command`) whose target velocities
    # this reward tracks.
    command_name: str = attrs.field(default="joystick_command")
    # Weight on the XY direction-tracking (cosine similarity) term.
    dir_scale: float = attrs.field(default=1.0)
    # Weight on the XY speed-magnitude-tracking term.
    mag_scale: float = attrs.field(default=1.0)
    # Weight on the yaw-rate-tracking term.
    yaw_scale: float = attrs.field(default=1.0)
756 | 763 |
|
757 | | - @xax.jit(static_argnames=["self"], jit_level=JitLevel.UNROLL) |
758 | | - def get_reward(self, trajectory: Trajectory) -> Array: |
| 764 | + def get_reward(self, trajectory: Trajectory) -> dict[str, Array]: |
759 | 765 | if self.command_name not in trajectory.command: |
760 | 766 | raise ValueError(f"Command {self.command_name} not found! Ensure that it is in the task.") |
761 | 767 | return self._get_reward_for(trajectory.command[self.command_name], trajectory) |
762 | 768 |
|
763 | | - def _get_reward_for(self, joystick_cmd: JoystickCommandValue, trajectory: Trajectory) -> Array: |
| 769 | + @classmethod |
| 770 | + def get_quat(cls, trajectory: Trajectory) -> Array: |
| 771 | + quat = trajectory.qpos[..., 3:7] |
| 772 | + yaw = xax.quat_to_yaw(quat) |
| 773 | + zeros = jnp.zeros_like(yaw) |
| 774 | + euler = jnp.stack([zeros, zeros, yaw], axis=-1) |
| 775 | + quat = xax.euler_to_quat(euler) |
| 776 | + return quat |
| 777 | + |
| 778 | + def _get_reward_for(self, joystick_cmd: JoystickCommandValue, trajectory: Trajectory) -> dict[str, Array]: |
764 | 779 | # Gets the target X, Y, and Yaw velocities. |
765 | 780 | tgts = joystick_cmd.vels |
766 | 781 |
|
767 | | - # Smooths the target velocities. |
768 | | - trg_xvel, trg_yvel, trg_yawvel = tgts.T |
769 | | - |
770 | 782 | # Gets the robot's current velocities. |
771 | | - cur_xvel = trajectory.qvel[..., 0] |
772 | | - cur_yvel = trajectory.qvel[..., 1] |
773 | | - cur_yawvel = trajectory.qvel[..., 5] |
774 | | - |
775 | | - # Gets the robot's current yaw. |
776 | | - quat = trajectory.qpos[..., 3:7] |
777 | | - cur_yaw = xax.quat_to_yaw(quat) |
778 | | - |
779 | | - # Rotates the command X and Y velocities to the robot's current yaw. |
780 | | - trg_xvel_rot = trg_xvel * jnp.cos(cur_yaw) - trg_yvel * jnp.sin(cur_yaw) |
781 | | - trg_yvel_rot = trg_xvel * jnp.sin(cur_yaw) + trg_yvel * jnp.cos(cur_yaw) |
782 | | - |
783 | | - # Linear reward for tracking the target velocities. |
784 | | - pos_x_rew = -jnp.abs(trg_xvel_rot - cur_xvel) |
785 | | - pos_y_rew = -jnp.abs(trg_yvel_rot - cur_yvel) |
786 | | - rot_z_rew = -jnp.abs(trg_yawvel - cur_yawvel) |
787 | | - |
788 | | - reward = (pos_x_rew + pos_y_rew + rot_z_rew) / 3.0 |
789 | | - |
790 | | - return reward |
| 783 | + quat = self.get_quat(trajectory) |
| 784 | + linvel = trajectory.qvel[..., :3] |
| 785 | + linvel = xax.rotate_vector_by_quat(linvel, quat, inverse=True) |
| 786 | + yawvel = trajectory.qvel[..., 5] |
| 787 | + |
| 788 | + # Reward for tracking the direction (cosine similarity). |
| 789 | + cur_xy = linvel[..., :2] |
| 790 | + trg_xy = tgts[..., :2] |
| 791 | + cur_norm = jnp.linalg.norm(cur_xy, axis=-1) |
| 792 | + trg_norm = jnp.linalg.norm(trg_xy, axis=-1) |
| 793 | + denom_xy = cur_norm * trg_norm |
| 794 | + xy_cos_sim = (cur_xy * trg_xy).sum(axis=-1) / denom_xy.clip(min=1e-6) |
| 795 | + |
| 796 | + # Reward for tracking the magnitude, in the direction of the target. |
| 797 | + xy_mag_rew = 1.0 - jnp.where(trg_norm < 1e-6, cur_norm, jnp.abs(cur_norm - trg_norm) / trg_norm.clip(min=1e-6)) |
| 798 | + |
| 799 | + # Reward for tracking the yaw. |
| 800 | + cur_yaw = yawvel |
| 801 | + trg_yaw = tgts[..., 2] |
| 802 | + yaw_mag_rew = 1.0 - jnp.abs(cur_yaw - trg_yaw) |
| 803 | + |
| 804 | + return { |
| 805 | + "dir": xy_cos_sim * self.dir_scale, |
| 806 | + "mag": xy_mag_rew * self.mag_scale, |
| 807 | + "yaw": yaw_mag_rew * self.yaw_scale, |
| 808 | + } |
791 | 809 |
|
792 | 810 | def get_markers(self) -> Collection[Marker]: |
793 | 811 | return [JoystickRewardMarker.get()] |
@@ -1054,19 +1072,25 @@ class EasyJoystickReward(StatefulReward): |
1054 | 1072 | def initial_carry(self, rng: PRNGKeyArray) -> Array: |
1055 | 1073 | return self.airtime.initial_carry(rng) |
1056 | 1074 |
|
1057 | | - def get_reward_stateful(self, trajectory: Trajectory, reward_carry: Array) -> tuple[Array, Array]: |
| 1075 | + def get_reward_stateful(self, trajectory: Trajectory, reward_carry: Array) -> tuple[dict[str, Array], Array]: |
1058 | 1076 | if self.command_name not in trajectory.command: |
1059 | 1077 | raise ValueError(f"Command {self.command_name} not found! Ensure that it is in the task.") |
1060 | 1078 |
|
1061 | 1079 | cmd: EasyJoystickCommandValue = trajectory.command[self.command_name] |
1062 | | - joystick_reward = self.joystick._get_reward_for(cmd.joystick, trajectory) * self.joystick.scale |
1063 | | - gait_reward = self.gait._get_reward_for(cmd.gait, trajectory) * self.gait.scale |
| 1080 | + joystick_reward = self.joystick._get_reward_for(cmd.joystick, trajectory) |
| 1081 | + gait_reward = self.gait._get_reward_for(cmd.gait, trajectory) |
1064 | 1082 | airtime_reward, airtime_carry = self.airtime.get_reward_stateful(trajectory, reward_carry) |
1065 | 1083 |
|
1066 | 1084 | # Mask out airtime reward when the robot is not moving. |
1067 | 1085 | airtime_reward = jnp.where(cmd.joystick.command.argmax(axis=-1) == 0, 0.0, airtime_reward) |
1068 | 1086 |
|
1069 | | - total_reward = joystick_reward + gait_reward + airtime_reward * self.airtime.scale |
| 1087 | + total_reward = { |
| 1088 | + "gait": gait_reward * self.gait.scale, |
| 1089 | + "airtime": airtime_reward * self.airtime.scale, |
| 1090 | + } |
| 1091 | + for k, v in joystick_reward.items(): |
| 1092 | + total_reward[f"joystick/{k}"] = v * self.joystick.scale |
| 1093 | + |
1070 | 1094 | return total_reward, airtime_carry |
1071 | 1095 |
|
1072 | 1096 | def get_markers(self) -> Collection[Marker]: |
|
0 commit comments