2828 "ReachabilityPenalty" ,
2929 "FeetAirTimeReward" ,
3030 "FeetHeightReward" ,
31- "FeetForcePenalty" ,
32- "FeetTorquePenalty" ,
31+ "ForcePenalty" ,
3332 "SinusoidalGaitReward" ,
3433 "BaseHeightTrackingReward" ,
3534]
@@ -778,26 +777,21 @@ def scan_fn(
778777class FeetHeightReward (StatefulReward ):
779778 """Reward for feet either touching or not touching the ground for some time."""
780779
781- period : float = attrs .field ()
782- ctrl_dt : float = attrs .field ()
783780 contact_obs : str = attrs .field ()
784781 position_obs : str = attrs .field ()
785782 height : float = attrs .field ()
786783 num_feet : int = attrs .field (default = 2 )
787784 linvel_moving_threshold : float = attrs .field (default = 0.05 )
788785 angvel_moving_threshold : float = attrs .field (default = 0.05 )
789786
790- def initial_carry (self , rng : PRNGKeyArray ) -> tuple [Array , Array ]:
791- return (
792- jnp .zeros (self .num_feet , dtype = jnp .float32 ),
793- jnp .zeros (self .num_feet , dtype = jnp .float32 ),
794- )
787+ def initial_carry (self , rng : PRNGKeyArray ) -> Array :
788+ return jnp .zeros (self .num_feet , dtype = jnp .float32 )
795789
796790 def get_reward_stateful (
797791 self ,
798792 trajectory : Trajectory ,
799- reward_carry : tuple [ Array , Array ] ,
800- ) -> tuple [Array , tuple [ Array , Array ] ]:
793+ reward_carry : Array ,
794+ ) -> tuple [Array , Array ]:
801795 not_moving_lin = jnp .linalg .norm (trajectory .qvel [..., :2 ], axis = - 1 ) < self .linvel_moving_threshold
802796 not_moving_ang = trajectory .qvel [..., 5 ] < self .angvel_moving_threshold
803797 not_moving = not_moving_lin & not_moving_ang
@@ -812,16 +806,14 @@ def get_reward_stateful(
812806 # Give a sparse reward once the foot contacts the ground, equal to the
813807 # maximum height of the foot since the last contact, thresholded at the
814808 # target height.
815- def scan_fn (carry : tuple [ Array , Array ], x : tuple [Array , Array , Array ]) -> tuple [tuple [ Array , Array ] , Array ]:
816- ( elapsed_time_n , max_height_n ) , (contact_n , position_n3 , not_moving ) = carry , x
809+ def scan_fn (carry : Array , x : tuple [Array , Array , Array ]) -> tuple [Array , Array ]:
810+ max_height_n , (contact_n , position_n3 , not_moving ) = carry , x
817811 height_n = position_n3 [..., 2 ]
818- scale = (elapsed_time_n / self .period ).clip (max = 1.0 )
819812 reset = not_moving | contact_n
820- reward_n = jnp .where (reset , max_height_n , 0.0 ).clip (max = self .height ) * scale
813+ reward_n = jnp .where (reset , max_height_n , 0.0 ).clip (max = self .height )
821814 max_height_n = jnp .maximum (max_height_n , height_n )
822815 max_height_n = jnp .where (reset , 0.0 , max_height_n )
823- elapsed_time_n = jnp .where (reset , 0.0 , elapsed_time_n + self .ctrl_dt )
824- return (elapsed_time_n , max_height_n ), reward_n
816+ return max_height_n , reward_n
825817
826818 reward_carry , reward_tn = xax .scan (
827819 scan_fn ,
@@ -833,26 +825,39 @@ def scan_fn(carry: tuple[Array, Array], x: tuple[Array, Array, Array]) -> tuple[
833825
834826
@attrs.define(frozen=True, kw_only=True)
class ForcePenalty(StatefulReward):
    """Penalty for large contact forces on some body.

    The raw force observation is smoothed with an exponential moving
    average (a first-order low-pass filter) to simulate compliance, since
    with stiff contacts the instantaneous force can spike very high.
    """

    force_obs: str = attrs.field()  # key into trajectory.obs for the force observation
    ctrl_dt: float = attrs.field()  # control timestep; sets the EMA decay per step
    ema_time: float = attrs.field(default=0.03)  # EMA time constant (same units as ctrl_dt)
    ema_scale: float = attrs.field(default=0.001)  # scale applied before log1p
    num_feet: int = attrs.field(default=2)  # number of tracked bodies (width of the carry)
    bias: float = attrs.field(default=0.0)  # offset subtracted from the force norm before clipping

    def initial_carry(self, rng: PRNGKeyArray) -> Array:
        """Start the per-body force EMA at zero."""
        return jnp.zeros(self.num_feet, dtype=jnp.float32)

    def get_reward_stateful(
        self,
        trajectory: Trajectory,
        reward_carry: Array,
    ) -> tuple[Array, Array]:
        """Return the per-step penalty and the updated EMA carry.

        The penalty at each step is log1p(ema_scale * ema) summed over the
        tracked bodies; log1p keeps the penalty bounded for very large forces.
        """
        # Decay factor of a first-order low-pass filter with time constant
        # `ema_time`, stepped at `ctrl_dt`.
        alpha = jnp.exp(-self.ctrl_dt / self.ema_time)
        force_tn = (jnp.linalg.norm(trajectory.obs[self.force_obs], axis=-1) - self.bias).clip(min=0)

        def step_ema(ema_n: Array, force_n: Array) -> tuple[Array, Array]:
            # Standard EMA update; the new value is both the carry and the output.
            smoothed_n = alpha * ema_n + (1 - alpha) * force_n
            return smoothed_n, smoothed_n

        final_ema, ema_tn = xax.scan(step_ema, reward_carry, force_tn)
        penalty_t = jnp.log1p(self.ema_scale * ema_tn).sum(axis=-1)
        return penalty_t, final_ema
856861
857862
858863@attrs .define (kw_only = True )
0 commit comments