Skip to content

Commit 76764ad

Browse files
authored
revive feet height reward (#506)
* revive feet height reward
* better scale
1 parent 28a2ea8 commit 76764ad

File tree

2 files changed

+72
-36
lines changed

2 files changed

+72
-36
lines changed

examples/walking.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -519,19 +519,25 @@ def get_rewards(self, physics_model: ksim.PhysicsModel) -> dict[str, ksim.Reward
519519
),
520520
"foot_airtime": ksim.FeetAirTimeReward(
521521
ctrl_dt=self.config.ctrl_dt,
522-
threshold=self.config.gait_period / 2.0,
522+
period=self.config.gait_period / 2.0,
523+
contact_obs="feet_contact",
524+
scale=1.0,
525+
),
526+
"foot_height": ksim.FeetHeightReward(
527+
ctrl_dt=self.config.ctrl_dt,
528+
period=self.config.gait_period / 2.0,
523529
contact_obs="feet_contact",
524530
position_obs="feet_position",
525531
height=self.config.max_foot_height,
526532
scale=1.0,
527533
),
528534
"foot_force": ksim.FeetForcePenalty(
529535
force_obs="feet_force",
530-
scale=-0.1,
536+
scale=-1e-6,
531537
),
532538
"foot_torque": ksim.FeetTorquePenalty(
533539
torque_obs="feet_torque",
534-
scale=-0.1,
540+
scale=-1e-7,
535541
),
536542
}
537543

ksim/rewards.py

Lines changed: 63 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"LinkJerkPenalty",
2828
"ReachabilityPenalty",
2929
"FeetAirTimeReward",
30+
"FeetHeightReward",
3031
"FeetForcePenalty",
3132
"FeetTorquePenalty",
3233
"SinusoidalGaitReward",
@@ -47,11 +48,7 @@
4748
from ksim.commands import AngularVelocityCommandValue, LinearVelocityCommandValue, SinusoidalGaitCommandValue
4849
from ksim.types import PhysicsModel, Trajectory
4950
from ksim.utils.mujoco import get_body_data_idx_from_name, get_qpos_data_idxs_by_name
50-
from ksim.utils.validators import (
51-
CartesianIndex,
52-
cartesian_index_to_dim,
53-
norm_validator,
54-
)
51+
from ksim.utils.validators import CartesianIndex, cartesian_index_to_dim, norm_validator
5552
from ksim.vis import Marker
5653

5754
logger = logging.getLogger(__name__)
@@ -729,27 +726,22 @@ def get_reward(self, traj: Trajectory) -> jnp.ndarray:
729726
class FeetAirTimeReward(StatefulReward):
730727
"""Reward for feet either touching or not touching the ground for some time."""
731728

732-
threshold: float = attrs.field()
729+
period: float = attrs.field()
733730
ctrl_dt: float = attrs.field()
734731
contact_obs: str = attrs.field()
735-
position_obs: str = attrs.field()
736-
height: float = attrs.field()
737732
num_feet: int = attrs.field(default=2)
738733
bias: float = attrs.field(default=0.0)
739734
linvel_moving_threshold: float = attrs.field(default=0.05)
740735
angvel_moving_threshold: float = attrs.field(default=0.05)
741736

742-
def initial_carry(self, rng: PRNGKeyArray) -> tuple[Array, Array]:
743-
return (
744-
jnp.zeros(self.num_feet, dtype=jnp.int32),
745-
jnp.zeros(self.num_feet, dtype=jnp.float32),
746-
)
737+
def initial_carry(self, rng: PRNGKeyArray) -> Array:
738+
return jnp.zeros(self.num_feet, dtype=jnp.int32)
747739

748740
def get_reward_stateful(
749741
self,
750742
trajectory: Trajectory,
751-
reward_carry: tuple[Array, Array],
752-
) -> tuple[Array, tuple[Array, Array]]:
743+
reward_carry: Array,
744+
) -> tuple[Array, Array]:
753745
not_moving_lin = jnp.linalg.norm(trajectory.qvel[..., :2], axis=-1) < self.linvel_moving_threshold
754746
not_moving_ang = trajectory.qvel[..., 5] < self.angvel_moving_threshold
755747
not_moving = not_moving_lin & not_moving_ang
@@ -758,39 +750,77 @@ def get_reward_stateful(
758750
sensor_data_tn = sensor_data_tcn.any(axis=-2)
759751
chex.assert_shape(sensor_data_tn, (..., self.num_feet))
760752

761-
position_tn3 = trajectory.obs[self.position_obs]
762-
chex.assert_shape(position_tn3, (..., self.num_feet, 3))
763-
764-
threshold_steps = round(self.threshold / self.ctrl_dt)
753+
threshold_steps = round(self.period / self.ctrl_dt)
765754

766755
def scan_fn(
767-
carry: tuple[Array, Array],
768-
x: tuple[Array, Array, Array, Array],
769-
) -> tuple[tuple[Array, Array], tuple[Array, Array]]:
770-
(count_n, max_height_n), (contact_n, position_n3, not_moving, done) = carry, x
756+
carry: Array,
757+
x: tuple[Array, Array, Array],
758+
) -> tuple[Array, Array]:
759+
count_n, (contact_n, not_moving, done) = carry, x
771760
reset = done | not_moving | contact_n
772761
count_n = jnp.where(reset, 0, count_n + 1)
762+
return count_n, count_n
773763

774-
height_n = position_n3[..., 2]
775-
max_height_n = jnp.where(reset, 0.0, jnp.maximum(max_height_n, height_n))
776-
777-
return (count_n, max_height_n), (count_n, max_height_n)
778-
779-
reward_carry, (count_tn, max_height_tn) = xax.scan(
764+
reward_carry, count_tn = xax.scan(
780765
scan_fn,
781766
reward_carry,
782-
(sensor_data_tn, position_tn3, not_moving, trajectory.done),
767+
(sensor_data_tn, not_moving, trajectory.done),
783768
)
784769

785770
# Gradually increase reward until `threshold_steps`.
786771
reward_tn = (count_tn.astype(jnp.float32) / threshold_steps) + self.bias
787772
reward_tn = jnp.where((count_tn > 0) & (count_tn < threshold_steps), reward_tn, 0.0)
773+
reward_t = reward_tn.sum(axis=-1)
774+
return reward_t, reward_carry
788775

789-
# Scale the reward according to the max height.
790-
reward_tn = reward_tn * max_height_tn.clip(max=self.height) / self.height
791776

792-
reward_t = reward_tn.sum(axis=-1)
777+
@attrs.define(frozen=True, kw_only=True)
778+
class FeetHeightReward(StatefulReward):
779+
"""Reward for feet either touching or not touching the ground for some time."""
780+
781+
period: float = attrs.field()
782+
ctrl_dt: float = attrs.field()
783+
contact_obs: str = attrs.field()
784+
position_obs: str = attrs.field()
785+
height: float = attrs.field()
786+
num_feet: int = attrs.field(default=2)
787+
bias: float = attrs.field(default=0.0)
788+
linvel_moving_threshold: float = attrs.field(default=0.05)
789+
angvel_moving_threshold: float = attrs.field(default=0.05)
790+
791+
def initial_carry(self, rng: PRNGKeyArray) -> tuple[Array, Array]:
792+
return (
793+
jnp.zeros(self.num_feet, dtype=jnp.float32),
794+
jnp.zeros(self.num_feet, dtype=jnp.float32),
795+
)
793796

797+
def get_reward_stateful(
798+
self,
799+
trajectory: Trajectory,
800+
reward_carry: tuple[Array, Array],
801+
) -> tuple[Array, tuple[Array, Array]]:
802+
contact_tcn = trajectory.obs[self.contact_obs] > 0.5 # Values are either 0 or 1.
803+
contact_tn = contact_tcn.any(axis=-2)
804+
chex.assert_shape(contact_tn, (..., self.num_feet))
805+
806+
position_tn3 = trajectory.obs[self.position_obs]
807+
chex.assert_shape(position_tn3, (..., self.num_feet, 3))
808+
809+
# Give a sparse reward once the foot contacts the ground, equal to the
810+
# maximum height of the foot since the last contact, thresholded at the
811+
# target height.
812+
def scan_fn(carry: tuple[Array, Array], x: tuple[Array, Array]) -> tuple[tuple[Array, Array], Array]:
813+
(elapsed_time_n, max_height_n), (contact_n, position_n3) = carry, x
814+
height_n = position_n3[..., 2]
815+
scale = (elapsed_time_n / self.period).clip(max=1.0)
816+
reward_n = jnp.where(contact_n, max_height_n, 0.0).clip(max=self.height) * scale
817+
max_height_n = jnp.maximum(max_height_n, height_n)
818+
max_height_n = jnp.where(contact_n, 0.0, max_height_n)
819+
elapsed_time_n = jnp.where(contact_n, 0.0, elapsed_time_n + self.ctrl_dt)
820+
return (elapsed_time_n, max_height_n), reward_n
821+
822+
reward_carry, reward_tn = xax.scan(scan_fn, reward_carry, (contact_tn, position_tn3))
823+
reward_t = reward_tn.max(axis=-1)
794824
return reward_t, reward_carry
795825

796826

0 commit comments

Comments
 (0)