66 "Reward" ,
77 "StatefulReward" ,
88 "StayAliveReward" ,
9- "LinearVelocityPenalty " ,
10- "AngularVelocityPenalty " ,
11- "OffAxisVelocityPenalty " ,
9+ "LinearVelocityReward " ,
10+ "AngularVelocityReward " ,
11+ "OffAxisVelocityReward " ,
1212 "BaseHeightReward" ,
1313 "BaseHeightRangeReward" ,
1414 "ActionVelocityPenalty" ,
@@ -239,11 +239,12 @@ def get(
239239
240240
241241@attrs .define (frozen = True , kw_only = True )
242- class LinearVelocityPenalty (Reward ):
242+ class LinearVelocityReward (Reward ):
243243 """Penalty for how fast the robot is moving in the z-direction."""
244244
245245 cmd : str = attrs .field ()
246- deadzone : float = attrs .field (default = 0.01 )
246+ vel_length_scale : float = attrs .field (default = 0.25 )
247+ yaw_length_scale : float = attrs .field (default = 0.25 )
247248 zero_threshold : float = attrs .field (default = 0.01 )
248249 vis_height : float = attrs .field (default = 0.6 )
249250
@@ -262,63 +263,56 @@ def get_reward(self, trajectory: Trajectory) -> dict[str, Array]:
262263 # Don't reward if the command is zero.
263264 is_zero = jnp .abs (cmd .vel ) < self .zero_threshold
264265
265- vel_diff = ( jnp .abs (vel - cmd .vel ) - self .deadzone ). clip ( min = 0.0 )
266- yaw_diff = ( jnp .abs (yaw - cmd .yaw ) - self .deadzone ). clip ( min = 0.0 )
267- x_diff = ( jnp .abs (x - cmd .xvel ) - self .deadzone ). clip ( min = 0.0 )
268- y_diff = ( jnp .abs (y - cmd .yvel ) - self .deadzone ). clip ( min = 0.0 )
266+ vel_rew = jnp . exp ( - jnp .square (vel - cmd .vel ) / ( 2 * self .vel_length_scale ** 2 ) )
267+ yaw_rew = jnp . exp ( - jnp .square (yaw - cmd .yaw ) / ( 2 * self .yaw_length_scale ** 2 ) )
268+ x_rew = jnp . exp ( - jnp .square (x - cmd .xvel ) / ( 2 * self .vel_length_scale ** 2 ) )
269+ y_rew = jnp . exp ( - jnp .square (y - cmd .yvel ) / ( 2 * self .vel_length_scale ** 2 ) )
269270
270271 return {
271- "vel_l1" : vel_diff ,
272- "vel_l2" : vel_diff ** 2 ,
273- "yaw_l1" : jnp .where (is_zero , 0.0 , yaw_diff ),
274- "yaw_l2" : jnp .where (is_zero , 0.0 , yaw_diff ** 2 ),
275- "x_l1" : x_diff ,
276- "x_l2" : x_diff ** 2 ,
277- "y_l1" : y_diff ,
278- "y_l2" : y_diff ** 2 ,
272+ "vel" : vel_rew ,
273+ "yaw" : jnp .where (is_zero , 0.0 , yaw_rew ),
274+ "x" : x_rew ,
275+ "y" : y_rew ,
279276 }
280277
281278 def get_markers (self , name : str ) -> Collection [Marker ]:
282279 return [LinearVelocityPenaltyMarker .get (height = self .vis_height )]
283280
284281
@attrs.define(frozen=True, kw_only=True)
class AngularVelocityReward(Reward):
    """Reward for matching the commanded angular velocity.

    The reward is a Gaussian-shaped kernel of the tracking error: it is 1 when
    the measured value equals the command and decays towards 0 as the error
    grows relative to ``angvel_length_scale``.
    """

    # Key into ``trajectory.command`` under which the target value is stored.
    cmd: str = attrs.field()
    # Width of the kernel; larger values tolerate larger tracking errors.
    angvel_length_scale: float = attrs.field(default=0.25)

    def get_reward(self, trajectory: Trajectory) -> dict[str, Array]:
        """Compute the angular-velocity tracking reward.

        Args:
            trajectory: Rollout providing ``qvel`` and the ``command`` mapping.

        Returns:
            A single-entry dict, ``{"angvel": reward}``.
        """
        target: AngularVelocityCommandValue = trajectory.command[self.cmd]
        # NOTE(review): index 5 is presumably the yaw rate of a floating
        # base -- confirm against the model's qvel layout.
        measured = trajectory.qvel[..., 5]
        squared_error = jnp.square(measured - target.vel)
        denominator = 2 * self.angvel_length_scale**2
        return {"angvel": jnp.exp(-squared_error / denominator)}
300296
301297
@attrs.define(frozen=True, kw_only=True)
class OffAxisVelocityReward(Reward):
    """Reward for keeping the non-commanded velocity components near zero.

    Each component is scored with a Gaussian-shaped kernel: the reward is 1
    when the component is exactly zero and decays towards 0 as its magnitude
    grows relative to the matching length scale.
    """

    # Kernel width for the linear (z) velocity term.
    lin_length_scale: float = attrs.field(default=0.25)
    # Kernel width for the angular (x and y) velocity terms.
    ang_length_scale: float = attrs.field(default=0.25)

    def get_reward(self, trajectory: Trajectory) -> dict[str, Array]:
        """Compute the off-axis velocity rewards.

        Args:
            trajectory: Rollout providing ``qvel``.

        Returns:
            A dict with the keys ``"linz"``, ``"angx"`` and ``"angy"``.
        """
        # NOTE(review): assumes qvel starts with a 6-DoF floating base laid
        # out as [lin x, lin y, lin z, ang x, ang y, ang z] -- confirm.
        linz = trajectory.qvel[..., 2]
        # The x/y angular components sit at indices 3 and 4. Index 5 is the
        # component AngularVelocityReward tracks against its command, so it
        # must not be driven to zero here.
        angx = trajectory.qvel[..., 3]
        angy = trajectory.qvel[..., 4]

        lin_denominator = 2 * self.lin_length_scale**2
        ang_denominator = 2 * self.ang_length_scale**2
        return {
            "linz": jnp.exp(-jnp.square(linz) / lin_denominator),
            "angx": jnp.exp(-jnp.square(angx) / ang_denominator),
            "angy": jnp.exp(-jnp.square(angy) / ang_denominator),
        }
323317
324318
0 commit comments