Skip to content

Commit bc20e65

Browse files
authored
use exp kernel (#515)
* use exp kernel
* off axis too
* lint
1 parent 1a4b9ec commit bc20e65

File tree

2 files changed

+31
-37
lines changed

2 files changed

+31
-37
lines changed

examples/walking.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -504,13 +504,13 @@ def get_rewards(self, physics_model: ksim.PhysicsModel) -> dict[str, ksim.Reward
504504
return {
505505
"stay_alive": ksim.StayAliveReward(scale=100.0),
506506
"upright": ksim.UprightReward(scale=5.0),
507-
"linvel": ksim.LinearVelocityPenalty(
507+
"linvel": ksim.LinearVelocityReward(
508508
cmd="linvel",
509-
scale=-0.1,
509+
scale=0.1,
510510
),
511-
"angvel": ksim.AngularVelocityPenalty(
511+
"angvel": ksim.AngularVelocityReward(
512512
cmd="angvel",
513-
scale=-0.01,
513+
scale=0.01,
514514
),
515515
"foot_airtime": ksim.FeetAirTimeReward(
516516
ctrl_dt=self.config.ctrl_dt,

ksim/rewards.py

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
"Reward",
77
"StatefulReward",
88
"StayAliveReward",
9-
"LinearVelocityPenalty",
10-
"AngularVelocityPenalty",
11-
"OffAxisVelocityPenalty",
9+
"LinearVelocityReward",
10+
"AngularVelocityReward",
11+
"OffAxisVelocityReward",
1212
"BaseHeightReward",
1313
"BaseHeightRangeReward",
1414
"ActionVelocityPenalty",
@@ -239,11 +239,12 @@ def get(
239239

240240

241241
@attrs.define(frozen=True, kw_only=True)
242-
class LinearVelocityPenalty(Reward):
242+
class LinearVelocityReward(Reward):
243243
"""Reward for tracking the commanded linear velocity, using an exponential kernel on the error."""
244244

245245
cmd: str = attrs.field()
246-
deadzone: float = attrs.field(default=0.01)
246+
vel_length_scale: float = attrs.field(default=0.25)
247+
yaw_length_scale: float = attrs.field(default=0.25)
247248
zero_threshold: float = attrs.field(default=0.01)
248249
vis_height: float = attrs.field(default=0.6)
249250

@@ -262,63 +263,56 @@ def get_reward(self, trajectory: Trajectory) -> dict[str, Array]:
262263
# Don't reward if the command is zero.
263264
is_zero = jnp.abs(cmd.vel) < self.zero_threshold
264265

265-
vel_diff = (jnp.abs(vel - cmd.vel) - self.deadzone).clip(min=0.0)
266-
yaw_diff = (jnp.abs(yaw - cmd.yaw) - self.deadzone).clip(min=0.0)
267-
x_diff = (jnp.abs(x - cmd.xvel) - self.deadzone).clip(min=0.0)
268-
y_diff = (jnp.abs(y - cmd.yvel) - self.deadzone).clip(min=0.0)
266+
vel_rew = jnp.exp(-jnp.square(vel - cmd.vel) / (2 * self.vel_length_scale**2))
267+
yaw_rew = jnp.exp(-jnp.square(yaw - cmd.yaw) / (2 * self.yaw_length_scale**2))
268+
x_rew = jnp.exp(-jnp.square(x - cmd.xvel) / (2 * self.vel_length_scale**2))
269+
y_rew = jnp.exp(-jnp.square(y - cmd.yvel) / (2 * self.vel_length_scale**2))
269270

270271
return {
271-
"vel_l1": vel_diff,
272-
"vel_l2": vel_diff**2,
273-
"yaw_l1": jnp.where(is_zero, 0.0, yaw_diff),
274-
"yaw_l2": jnp.where(is_zero, 0.0, yaw_diff**2),
275-
"x_l1": x_diff,
276-
"x_l2": x_diff**2,
277-
"y_l1": y_diff,
278-
"y_l2": y_diff**2,
272+
"vel": vel_rew,
273+
"yaw": jnp.where(is_zero, 0.0, yaw_rew),
274+
"x": x_rew,
275+
"y": y_rew,
279276
}
280277

281278
def get_markers(self, name: str) -> Collection[Marker]:
282279
return [LinearVelocityPenaltyMarker.get(height=self.vis_height)]
283280

284281

285282
@attrs.define(frozen=True, kw_only=True)
286-
class AngularVelocityPenalty(Reward):
283+
class AngularVelocityReward(Reward):
287284
"""Reward for tracking the commanded angular velocity, using an exponential kernel on the error."""
288285

289286
cmd: str = attrs.field()
290-
deadzone: float = attrs.field(default=0.01)
287+
angvel_length_scale: float = attrs.field(default=0.25)
291288

292289
def get_reward(self, trajectory: Trajectory) -> dict[str, Array]:
293290
cmd: AngularVelocityCommandValue = trajectory.command[self.cmd]
294291
angvel = trajectory.qvel[..., 5]
295-
angvel_diff = (jnp.abs(angvel - cmd.vel) - self.deadzone).clip(min=0.0)
292+
angvel_rew = jnp.exp(-jnp.square(angvel - cmd.vel) / (2 * self.angvel_length_scale**2))
296293
return {
297-
"angvel_l1": angvel_diff,
298-
"angvel_l2": angvel_diff**2,
294+
"angvel": angvel_rew,
299295
}
300296

301297

302298
@attrs.define(frozen=True, kw_only=True)
303-
class OffAxisVelocityPenalty(Reward):
299+
class OffAxisVelocityReward(Reward):
304300
"""Rewards keeping velocities in the off-command directions close to zero, using an exponential kernel."""
305301

306-
deadzone: float = attrs.field(default=0.01)
302+
lin_length_scale: float = attrs.field(default=0.25)
303+
ang_length_scale: float = attrs.field(default=0.25)
307304

308305
def get_reward(self, trajectory: Trajectory) -> dict[str, Array]:
309306
linz = trajectory.qvel[..., 2]
310307
angx = trajectory.qvel[..., 4]
311308
angy = trajectory.qvel[..., 5]
312-
linz_diff = (jnp.abs(linz) - self.deadzone).clip(min=0.0)
313-
angx_diff = (jnp.abs(angx) - self.deadzone).clip(min=0.0)
314-
angy_diff = (jnp.abs(angy) - self.deadzone).clip(min=0.0)
309+
linz_rew = jnp.exp(-jnp.square(linz) / (2 * self.lin_length_scale**2))
310+
angx_rew = jnp.exp(-jnp.square(angx) / (2 * self.ang_length_scale**2))
311+
angy_rew = jnp.exp(-jnp.square(angy) / (2 * self.ang_length_scale**2))
315312
return {
316-
"linz_l1": linz_diff,
317-
"linz_l2": linz_diff**2,
318-
"angx_l1": angx_diff,
319-
"angx_l2": angx_diff**2,
320-
"angy_l1": angy_diff,
321-
"angy_l2": angy_diff**2,
313+
"linz": linz_rew,
314+
"angx": angx_rew,
315+
"angy": angy_rew,
322316
}
323317

324318

0 commit comments

Comments
 (0)