2727 "LinkJerkPenalty" ,
2828 "ReachabilityPenalty" ,
2929 "FeetAirTimeReward" ,
30+ "FeetHeightReward" ,
3031 "FeetForcePenalty" ,
3132 "FeetTorquePenalty" ,
3233 "SinusoidalGaitReward" ,
4748from ksim .commands import AngularVelocityCommandValue , LinearVelocityCommandValue , SinusoidalGaitCommandValue
4849from ksim .types import PhysicsModel , Trajectory
4950from ksim .utils .mujoco import get_body_data_idx_from_name , get_qpos_data_idxs_by_name
50- from ksim .utils .validators import (
51- CartesianIndex ,
52- cartesian_index_to_dim ,
53- norm_validator ,
54- )
51+ from ksim .utils .validators import CartesianIndex , cartesian_index_to_dim , norm_validator
5552from ksim .vis import Marker
5653
5754logger = logging .getLogger (__name__ )
@@ -729,27 +726,22 @@ def get_reward(self, traj: Trajectory) -> jnp.ndarray:
729726class FeetAirTimeReward (StatefulReward ):
730727 """Reward for feet either touching or not touching the ground for some time."""
731728
732- threshold : float = attrs .field ()
729+ period : float = attrs .field ()
733730 ctrl_dt : float = attrs .field ()
734731 contact_obs : str = attrs .field ()
735- position_obs : str = attrs .field ()
736- height : float = attrs .field ()
737732 num_feet : int = attrs .field (default = 2 )
738733 bias : float = attrs .field (default = 0.0 )
739734 linvel_moving_threshold : float = attrs .field (default = 0.05 )
740735 angvel_moving_threshold : float = attrs .field (default = 0.05 )
741736
742- def initial_carry (self , rng : PRNGKeyArray ) -> tuple [Array , Array ]:
743- return (
744- jnp .zeros (self .num_feet , dtype = jnp .int32 ),
745- jnp .zeros (self .num_feet , dtype = jnp .float32 ),
746- )
    def initial_carry(self, rng: PRNGKeyArray) -> Array:
        # Per-foot air-time step counter, shape (num_feet,), int32. Each entry
        # counts consecutive control steps the foot has satisfied the "in the
        # air while the robot is moving" condition; the scan in
        # `get_reward_stateful` resets it to zero on contact, episode `done`,
        # or when the base is not moving. `rng` is unused here.
        return jnp.zeros(self.num_feet, dtype=jnp.int32)
747739
748740 def get_reward_stateful (
749741 self ,
750742 trajectory : Trajectory ,
751- reward_carry : tuple [ Array , Array ] ,
752- ) -> tuple [Array , tuple [ Array , Array ] ]:
743+ reward_carry : Array ,
744+ ) -> tuple [Array , Array ]:
753745 not_moving_lin = jnp .linalg .norm (trajectory .qvel [..., :2 ], axis = - 1 ) < self .linvel_moving_threshold
754746 not_moving_ang = trajectory .qvel [..., 5 ] < self .angvel_moving_threshold
755747 not_moving = not_moving_lin & not_moving_ang
@@ -758,39 +750,77 @@ def get_reward_stateful(
758750 sensor_data_tn = sensor_data_tcn .any (axis = - 2 )
759751 chex .assert_shape (sensor_data_tn , (..., self .num_feet ))
760752
761- position_tn3 = trajectory .obs [self .position_obs ]
762- chex .assert_shape (position_tn3 , (..., self .num_feet , 3 ))
763-
764- threshold_steps = round (self .threshold / self .ctrl_dt )
753+ threshold_steps = round (self .period / self .ctrl_dt )
765754
766755 def scan_fn (
767- carry : tuple [ Array , Array ] ,
768- x : tuple [Array , Array , Array , Array ],
769- ) -> tuple [tuple [ Array , Array ], tuple [ Array , Array ] ]:
770- ( count_n , max_height_n ), (contact_n , position_n3 , not_moving , done ) = carry , x
756+ carry : Array ,
757+ x : tuple [Array , Array , Array ],
758+ ) -> tuple [Array , Array ]:
759+ count_n , (contact_n , not_moving , done ) = carry , x
771760 reset = done | not_moving | contact_n
772761 count_n = jnp .where (reset , 0 , count_n + 1 )
762+ return count_n , count_n
773763
774- height_n = position_n3 [..., 2 ]
775- max_height_n = jnp .where (reset , 0.0 , jnp .maximum (max_height_n , height_n ))
776-
777- return (count_n , max_height_n ), (count_n , max_height_n )
778-
779- reward_carry , (count_tn , max_height_tn ) = xax .scan (
764+ reward_carry , count_tn = xax .scan (
780765 scan_fn ,
781766 reward_carry ,
782- (sensor_data_tn , position_tn3 , not_moving , trajectory .done ),
767+ (sensor_data_tn , not_moving , trajectory .done ),
783768 )
784769
785770 # Gradually increase reward until `threshold_steps`.
786771 reward_tn = (count_tn .astype (jnp .float32 ) / threshold_steps ) + self .bias
787772 reward_tn = jnp .where ((count_tn > 0 ) & (count_tn < threshold_steps ), reward_tn , 0.0 )
773+ reward_t = reward_tn .sum (axis = - 1 )
774+ return reward_t , reward_carry
788775
789- # Scale the reward according to the max height.
790- reward_tn = reward_tn * max_height_tn .clip (max = self .height ) / self .height
791776
792- reward_t = reward_tn .sum (axis = - 1 )
@attrs.define(frozen=True, kw_only=True)
class FeetHeightReward(StatefulReward):
    """Sparse reward for lifting each foot toward a target height between contacts.

    For every foot, the maximum foot height reached since the previous ground
    contact is tracked. When the foot touches down, a one-step reward equal to
    that maximum height (clipped at `height`) is emitted, scaled by how much of
    `period` has elapsed since the last contact (capped at 1). Between
    contacts the reward is zero.
    """

    # Target swing duration in seconds; the contact reward is scaled by
    # elapsed_time / period, clipped at 1.
    period: float = attrs.field()
    # Control timestep in seconds, used to accumulate elapsed swing time.
    ctrl_dt: float = attrs.field()
    # Key into `trajectory.obs` for the per-foot contact observation
    # (shape (..., C, num_feet); values are 0/1 per the check below).
    contact_obs: str = attrs.field()
    # Key into `trajectory.obs` for per-foot positions, shape (..., num_feet, 3).
    position_obs: str = attrs.field()
    # Target foot-lift height in meters; the rewarded max height is clipped here.
    height: float = attrs.field()
    num_feet: int = attrs.field(default=2)
    # NOTE(review): `bias`, `linvel_moving_threshold`, and
    # `angvel_moving_threshold` are never read in this class — they appear to
    # be carried over from FeetAirTimeReward; confirm whether they are needed.
    bias: float = attrs.field(default=0.0)
    linvel_moving_threshold: float = attrs.field(default=0.05)
    angvel_moving_threshold: float = attrs.field(default=0.05)

    def initial_carry(self, rng: PRNGKeyArray) -> tuple[Array, Array]:
        # Carry is (elapsed_time_n, max_height_n), each shape (num_feet,),
        # float32: seconds since last contact, and max foot height since last
        # contact. `rng` is unused here.
        return (
            jnp.zeros(self.num_feet, dtype=jnp.float32),
            jnp.zeros(self.num_feet, dtype=jnp.float32),
        )

    def get_reward_stateful(
        self,
        trajectory: Trajectory,
        reward_carry: tuple[Array, Array],
    ) -> tuple[Array, tuple[Array, Array]]:
        contact_tcn = trajectory.obs[self.contact_obs] > 0.5  # Values are either 0 or 1.
        # Collapse the per-contact-point axis: a foot is "in contact" if any
        # of its contact points is.
        contact_tn = contact_tcn.any(axis=-2)
        chex.assert_shape(contact_tn, (..., self.num_feet))

        position_tn3 = trajectory.obs[self.position_obs]
        chex.assert_shape(position_tn3, (..., self.num_feet, 3))

        # Give a sparse reward once the foot contacts the ground, equal to the
        # maximum height of the foot since the last contact, thresholded at the
        # target height.
        def scan_fn(carry: tuple[Array, Array], x: tuple[Array, Array]) -> tuple[tuple[Array, Array], Array]:
            (elapsed_time_n, max_height_n), (contact_n, position_n3) = carry, x
            height_n = position_n3[..., 2]
            # Fraction of the target swing period completed, capped at 1;
            # short swings earn proportionally less.
            scale = (elapsed_time_n / self.period).clip(max=1.0)
            # Reward fires only on the contact step. Note: it uses the max
            # height accumulated up to the *previous* step — this step's
            # height is folded into the carry below, after the reward.
            reward_n = jnp.where(contact_n, max_height_n, 0.0).clip(max=self.height) * scale
            max_height_n = jnp.maximum(max_height_n, height_n)
            # Reset both accumulators for feet that touched down this step.
            max_height_n = jnp.where(contact_n, 0.0, max_height_n)
            elapsed_time_n = jnp.where(contact_n, 0.0, elapsed_time_n + self.ctrl_dt)
            return (elapsed_time_n, max_height_n), reward_n

        reward_carry, reward_tn = xax.scan(scan_fn, reward_carry, (contact_tn, position_tn3))
        # Takes the max over feet: only the best-scoring foot contributes at
        # each timestep (unlike FeetAirTimeReward, which sums over feet).
        reward_t = reward_tn.max(axis=-1)
        return reward_t, reward_carry
795825
796826
0 commit comments