Skip to content

Commit f035f9f

Browse files
authored
better force penalty (#508)
* better force penalty * typo * fix force penalty * comment * corrected * formatting * remove torque
1 parent b26ee9b commit f035f9f

File tree

2 files changed

+40
-47
lines changed

2 files changed

+40
-47
lines changed

examples/walking.py

Lines changed: 5 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -481,11 +481,6 @@ def get_observations(self, physics_model: ksim.PhysicsModel) -> dict[str, ksim.O
481481
foot_left_body_name="foot_left",
482482
foot_right_body_name="foot_right",
483483
),
484-
"feet_torque": ksim.FeetTorqueObservation.create(
485-
physics_model=physics_model,
486-
foot_left_body_name="foot_left",
487-
foot_right_body_name="foot_right",
488-
),
489484
}
490485

491486
def get_commands(self, physics_model: ksim.PhysicsModel) -> dict[str, ksim.Command]:
@@ -524,20 +519,16 @@ def get_rewards(self, physics_model: ksim.PhysicsModel) -> dict[str, ksim.Reward
524519
scale=1.0,
525520
),
526521
"foot_height": ksim.FeetHeightReward(
527-
ctrl_dt=self.config.ctrl_dt,
528-
period=self.config.gait_period / 2.0,
529522
contact_obs="feet_contact",
530523
position_obs="feet_position",
531524
height=self.config.max_foot_height,
532525
scale=1.0,
533526
),
534-
"foot_force": ksim.FeetForcePenalty(
527+
"foot_force": ksim.ForcePenalty(
535528
force_obs="feet_force",
536-
scale=-1e-6,
537-
),
538-
"foot_torque": ksim.FeetTorquePenalty(
539-
torque_obs="feet_torque",
540-
scale=-1e-7,
529+
ctrl_dt=self.config.ctrl_dt,
530+
bias=100.0,
531+
scale=-1e-1,
541532
),
542533
}
543534

@@ -560,7 +551,7 @@ def get_model(self, params: ksim.InitParams) -> Model:
560551
params.key,
561552
physics_model=params.physics_model,
562553
num_actor_inputs=43,
563-
num_critic_inputs=346,
554+
num_critic_inputs=340,
564555
num_joints=17,
565556
min_std=0.01,
566557
max_std=1.0,
@@ -638,11 +629,9 @@ def run_critic(
638629
lin_vel_obs_3 = observations["base_linear_velocity"]
639630
ang_vel_obs_3 = observations["base_angular_velocity"]
640631
feet_force_obs_23 = observations["feet_force"]
641-
feet_torque_obs_23 = observations["feet_torque"]
642632

643633
# Flattens the last two dimensions.
644634
feet_force_obs_6 = feet_force_obs_23.reshape(*feet_force_obs_23.shape[:-2], 6)
645-
feet_torque_obs_6 = feet_torque_obs_23.reshape(*feet_torque_obs_23.shape[:-2], 6)
646635

647636
# Command tensors.
648637
linvel_cmd: ksim.LinearVelocityCommandValue = commands["linvel"]
@@ -667,7 +656,6 @@ def run_critic(
667656
lin_vel_obs_3, # 3
668657
ang_vel_obs_3, # 3
669658
feet_force_obs_6, # 6
670-
feet_torque_obs_6, # 6
671659
linvel_cmd_4, # 4
672660
angvel_cmd_1, # 1
673661
],

ksim/rewards.py

Lines changed: 35 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -28,8 +28,7 @@
2828
"ReachabilityPenalty",
2929
"FeetAirTimeReward",
3030
"FeetHeightReward",
31-
"FeetForcePenalty",
32-
"FeetTorquePenalty",
31+
"ForcePenalty",
3332
"SinusoidalGaitReward",
3433
"BaseHeightTrackingReward",
3534
]
@@ -778,26 +777,21 @@ def scan_fn(
778777
class FeetHeightReward(StatefulReward):
779778
"""Reward for feet either touching or not touching the ground for some time."""
780779

781-
period: float = attrs.field()
782-
ctrl_dt: float = attrs.field()
783780
contact_obs: str = attrs.field()
784781
position_obs: str = attrs.field()
785782
height: float = attrs.field()
786783
num_feet: int = attrs.field(default=2)
787784
linvel_moving_threshold: float = attrs.field(default=0.05)
788785
angvel_moving_threshold: float = attrs.field(default=0.05)
789786

790-
def initial_carry(self, rng: PRNGKeyArray) -> tuple[Array, Array]:
791-
return (
792-
jnp.zeros(self.num_feet, dtype=jnp.float32),
793-
jnp.zeros(self.num_feet, dtype=jnp.float32),
794-
)
787+
def initial_carry(self, rng: PRNGKeyArray) -> Array:
788+
return jnp.zeros(self.num_feet, dtype=jnp.float32)
795789

796790
def get_reward_stateful(
797791
self,
798792
trajectory: Trajectory,
799-
reward_carry: tuple[Array, Array],
800-
) -> tuple[Array, tuple[Array, Array]]:
793+
reward_carry: Array,
794+
) -> tuple[Array, Array]:
801795
not_moving_lin = jnp.linalg.norm(trajectory.qvel[..., :2], axis=-1) < self.linvel_moving_threshold
802796
not_moving_ang = trajectory.qvel[..., 5] < self.angvel_moving_threshold
803797
not_moving = not_moving_lin & not_moving_ang
@@ -812,16 +806,14 @@ def get_reward_stateful(
812806
# Give a sparse reward once the foot contacts the ground, equal to the
813807
# maximum height of the foot since the last contact, thresholded at the
814808
# target height.
815-
def scan_fn(carry: tuple[Array, Array], x: tuple[Array, Array, Array]) -> tuple[tuple[Array, Array], Array]:
816-
(elapsed_time_n, max_height_n), (contact_n, position_n3, not_moving) = carry, x
809+
def scan_fn(carry: Array, x: tuple[Array, Array, Array]) -> tuple[Array, Array]:
810+
max_height_n, (contact_n, position_n3, not_moving) = carry, x
817811
height_n = position_n3[..., 2]
818-
scale = (elapsed_time_n / self.period).clip(max=1.0)
819812
reset = not_moving | contact_n
820-
reward_n = jnp.where(reset, max_height_n, 0.0).clip(max=self.height) * scale
813+
reward_n = jnp.where(reset, max_height_n, 0.0).clip(max=self.height)
821814
max_height_n = jnp.maximum(max_height_n, height_n)
822815
max_height_n = jnp.where(reset, 0.0, max_height_n)
823-
elapsed_time_n = jnp.where(reset, 0.0, elapsed_time_n + self.ctrl_dt)
824-
return (elapsed_time_n, max_height_n), reward_n
816+
return max_height_n, reward_n
825817

826818
reward_carry, reward_tn = xax.scan(
827819
scan_fn,
@@ -833,26 +825,39 @@ def scan_fn(carry: tuple[Array, Array], x: tuple[Array, Array, Array]) -> tuple[
833825

834826

835827
@attrs.define(frozen=True, kw_only=True)
836-
class FeetForcePenalty(Reward):
837-
"""Reward for reducing the force on the feet."""
828+
class ForcePenalty(StatefulReward):
829+
"""Reward for reducing the force on some body.
830+
831+
This is modeled with a low-pass filter to simulate compliance, since when
832+
using stiff contacts the force can sometimes be very high.
833+
"""
838834

839835
force_obs: str = attrs.field()
836+
ctrl_dt: float = attrs.field()
837+
ema_time: float = attrs.field(default=0.03)
838+
ema_scale: float = attrs.field(default=0.001)
839+
num_feet: int = attrs.field(default=2)
840840
bias: float = attrs.field(default=0.0)
841841

842-
def get_reward(self, trajectory: Trajectory) -> Array:
843-
force_t = (jnp.linalg.norm(trajectory.obs[self.force_obs], axis=-1) - self.bias).clip(min=0.0)
844-
return force_t.sum(axis=-1) ** 2
845-
842+
def initial_carry(self, rng: PRNGKeyArray) -> Array:
843+
return jnp.zeros(self.num_feet, dtype=jnp.float32)
846844

847-
@attrs.define(frozen=True, kw_only=True)
848-
class FeetTorquePenalty(Reward):
849-
"""Reward for reducing the force on the feet."""
845+
def get_reward_stateful(
846+
self,
847+
trajectory: Trajectory,
848+
reward_carry: Array,
849+
) -> tuple[Array, Array]:
850+
alpha = jnp.exp(-self.ctrl_dt / self.ema_time)
851+
obs = (jnp.linalg.norm(trajectory.obs[self.force_obs], axis=-1) - self.bias).clip(min=0)
850852

851-
torque_obs: str = attrs.field()
853+
def scan_fn(carry: Array, x: Array) -> tuple[Array, Array]:
854+
ema_n, obs_n = carry, x
855+
ema_n = alpha * ema_n + (1 - alpha) * obs_n
856+
return ema_n, ema_n
852857

853-
def get_reward(self, trajectory: Trajectory) -> Array:
854-
torque_t = jnp.linalg.norm(trajectory.obs[self.torque_obs], axis=-1)
855-
return torque_t.sum(axis=-1) ** 2
858+
ema_fn, ema_acc = xax.scan(scan_fn, reward_carry, obs)
859+
penalty = jnp.log1p(self.ema_scale * ema_acc).sum(axis=-1)
860+
return penalty, ema_fn
856861

857862

858863
@attrs.define(kw_only=True)

0 commit comments

Comments
 (0)