Skip to content

Commit cbba30f

Browse files
authored
better feet height reward (#507)
* better feet height reward
* remove unused term
1 parent 76764ad commit cbba30f

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

ksim/rewards.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -746,9 +746,9 @@ def get_reward_stateful(
746746
not_moving_ang = trajectory.qvel[..., 5] < self.angvel_moving_threshold
747747
not_moving = not_moving_lin & not_moving_ang
748748

749-
sensor_data_tcn = trajectory.obs[self.contact_obs] > 0.5 # Values are either 0 or 1.
750-
sensor_data_tn = sensor_data_tcn.any(axis=-2)
751-
chex.assert_shape(sensor_data_tn, (..., self.num_feet))
749+
contact_tcn = trajectory.obs[self.contact_obs] > 0.5 # Values are either 0 or 1.
750+
contact_tn = contact_tcn.any(axis=-2)
751+
chex.assert_shape(contact_tn, (..., self.num_feet))
752752

753753
threshold_steps = round(self.period / self.ctrl_dt)
754754

@@ -764,7 +764,7 @@ def scan_fn(
764764
reward_carry, count_tn = xax.scan(
765765
scan_fn,
766766
reward_carry,
767-
(sensor_data_tn, not_moving, trajectory.done),
767+
(contact_tn, not_moving, trajectory.done),
768768
)
769769

770770
# Gradually increase reward until `threshold_steps`.
@@ -784,7 +784,6 @@ class FeetHeightReward(StatefulReward):
784784
position_obs: str = attrs.field()
785785
height: float = attrs.field()
786786
num_feet: int = attrs.field(default=2)
787-
bias: float = attrs.field(default=0.0)
788787
linvel_moving_threshold: float = attrs.field(default=0.05)
789788
angvel_moving_threshold: float = attrs.field(default=0.05)
790789

@@ -799,6 +798,10 @@ def get_reward_stateful(
799798
trajectory: Trajectory,
800799
reward_carry: tuple[Array, Array],
801800
) -> tuple[Array, tuple[Array, Array]]:
801+
not_moving_lin = jnp.linalg.norm(trajectory.qvel[..., :2], axis=-1) < self.linvel_moving_threshold
802+
not_moving_ang = trajectory.qvel[..., 5] < self.angvel_moving_threshold
803+
not_moving = not_moving_lin & not_moving_ang
804+
802805
contact_tcn = trajectory.obs[self.contact_obs] > 0.5 # Values are either 0 or 1.
803806
contact_tn = contact_tcn.any(axis=-2)
804807
chex.assert_shape(contact_tn, (..., self.num_feet))
@@ -809,17 +812,22 @@ def get_reward_stateful(
809812
# Give a sparse reward once the foot contacts the ground, equal to the
810813
# maximum height of the foot since the last contact, thresholded at the
811814
# target height.
812-
def scan_fn(carry: tuple[Array, Array], x: tuple[Array, Array]) -> tuple[tuple[Array, Array], Array]:
813-
(elapsed_time_n, max_height_n), (contact_n, position_n3) = carry, x
815+
def scan_fn(carry: tuple[Array, Array], x: tuple[Array, Array, Array]) -> tuple[tuple[Array, Array], Array]:
816+
(elapsed_time_n, max_height_n), (contact_n, position_n3, not_moving) = carry, x
814817
height_n = position_n3[..., 2]
815818
scale = (elapsed_time_n / self.period).clip(max=1.0)
816-
reward_n = jnp.where(contact_n, max_height_n, 0.0).clip(max=self.height) * scale
819+
reset = not_moving | contact_n
820+
reward_n = jnp.where(reset, max_height_n, 0.0).clip(max=self.height) * scale
817821
max_height_n = jnp.maximum(max_height_n, height_n)
818-
max_height_n = jnp.where(contact_n, 0.0, max_height_n)
819-
elapsed_time_n = jnp.where(contact_n, 0.0, elapsed_time_n + self.ctrl_dt)
822+
max_height_n = jnp.where(reset, 0.0, max_height_n)
823+
elapsed_time_n = jnp.where(reset, 0.0, elapsed_time_n + self.ctrl_dt)
820824
return (elapsed_time_n, max_height_n), reward_n
821825

822-
reward_carry, reward_tn = xax.scan(scan_fn, reward_carry, (contact_tn, position_tn3))
826+
reward_carry, reward_tn = xax.scan(
827+
scan_fn,
828+
reward_carry,
829+
(contact_tn, position_tn3, not_moving | trajectory.done),
830+
)
823831
reward_t = reward_tn.max(axis=-1)
824832
return reward_t, reward_carry
825833

0 commit comments

Comments (0)