@@ -746,9 +746,9 @@ def get_reward_stateful(
746746 not_moving_ang = trajectory .qvel [..., 5 ] < self .angvel_moving_threshold
747747 not_moving = not_moving_lin & not_moving_ang
748748
749- sensor_data_tcn = trajectory .obs [self .contact_obs ] > 0.5 # Values are either 0 or 1.
750- sensor_data_tn = sensor_data_tcn .any (axis = - 2 )
751- chex .assert_shape (sensor_data_tn , (..., self .num_feet ))
749+ contact_tcn = trajectory .obs [self .contact_obs ] > 0.5 # Values are either 0 or 1.
750+ contact_tn = contact_tcn .any (axis = - 2 )
751+ chex .assert_shape (contact_tn , (..., self .num_feet ))
752752
753753 threshold_steps = round (self .period / self .ctrl_dt )
754754
@@ -764,7 +764,7 @@ def scan_fn(
764764 reward_carry , count_tn = xax .scan (
765765 scan_fn ,
766766 reward_carry ,
767- (sensor_data_tn , not_moving , trajectory .done ),
767+ (contact_tn , not_moving , trajectory .done ),
768768 )
769769
770770 # Gradually increase reward until `threshold_steps`.
@@ -784,7 +784,6 @@ class FeetHeightReward(StatefulReward):
784784 position_obs : str = attrs .field ()
785785 height : float = attrs .field ()
786786 num_feet : int = attrs .field (default = 2 )
787- bias : float = attrs .field (default = 0.0 )
788787 linvel_moving_threshold : float = attrs .field (default = 0.05 )
789788 angvel_moving_threshold : float = attrs .field (default = 0.05 )
790789
@@ -799,6 +798,10 @@ def get_reward_stateful(
799798 trajectory : Trajectory ,
800799 reward_carry : tuple [Array , Array ],
801800 ) -> tuple [Array , tuple [Array , Array ]]:
801+ not_moving_lin = jnp .linalg .norm (trajectory .qvel [..., :2 ], axis = - 1 ) < self .linvel_moving_threshold
802+ not_moving_ang = trajectory .qvel [..., 5 ] < self .angvel_moving_threshold
803+ not_moving = not_moving_lin & not_moving_ang
804+
802805 contact_tcn = trajectory .obs [self .contact_obs ] > 0.5 # Values are either 0 or 1.
803806 contact_tn = contact_tcn .any (axis = - 2 )
804807 chex .assert_shape (contact_tn , (..., self .num_feet ))
@@ -809,17 +812,22 @@ def get_reward_stateful(
809812 # Give a sparse reward once the foot contacts the ground, equal to the
810813 # maximum height of the foot since the last contact, thresholded at the
811814 # target height.
812- def scan_fn (carry : tuple [Array , Array ], x : tuple [Array , Array ]) -> tuple [tuple [Array , Array ], Array ]:
813- (elapsed_time_n , max_height_n ), (contact_n , position_n3 ) = carry , x
815+ def scan_fn (carry : tuple [Array , Array ], x : tuple [Array , Array , Array ]) -> tuple [tuple [Array , Array ], Array ]:
816+ (elapsed_time_n , max_height_n ), (contact_n , position_n3 , not_moving ) = carry , x
814817 height_n = position_n3 [..., 2 ]
815818 scale = (elapsed_time_n / self .period ).clip (max = 1.0 )
816- reward_n = jnp .where (contact_n , max_height_n , 0.0 ).clip (max = self .height ) * scale
819+ reset = not_moving | contact_n
820+ reward_n = jnp .where (reset , max_height_n , 0.0 ).clip (max = self .height ) * scale
817821 max_height_n = jnp .maximum (max_height_n , height_n )
818- max_height_n = jnp .where (contact_n , 0.0 , max_height_n )
819- elapsed_time_n = jnp .where (contact_n , 0.0 , elapsed_time_n + self .ctrl_dt )
822+ max_height_n = jnp .where (reset , 0.0 , max_height_n )
823+ elapsed_time_n = jnp .where (reset , 0.0 , elapsed_time_n + self .ctrl_dt )
820824 return (elapsed_time_n , max_height_n ), reward_n
821825
822- reward_carry , reward_tn = xax .scan (scan_fn , reward_carry , (contact_tn , position_tn3 ))
826+ reward_carry , reward_tn = xax .scan (
827+ scan_fn ,
828+ reward_carry ,
829+ (contact_tn , position_tn3 , not_moving | trajectory .done ),
830+ )
823831 reward_t = reward_tn .max (axis = - 1 )
824832 return reward_t , reward_carry
825833
0 commit comments