@@ -172,27 +172,31 @@ def __call__(self, state: PhysicsData, curriculum_level: Array) -> Array:
172172class FarFromOriginTermination (Termination ):
173173 """Terminates the episode if the robot is too far from the origin.
174174
175- This is treated as a positive termination.
175+ Defaults to a positive termination.
176176 """
177177
178178 max_dist : float = attrs .field (validator = attrs .validators .gt (0.0 ))
179+ pos_termination : bool = attrs .field (default = True )
179180
180181 def __call__ (self , state : PhysicsData , curriculum_level : Array ) -> Array :
181- return jnp .where (jnp .linalg .norm (state .qpos [..., :3 ], axis = - 1 ) > self .max_dist , 1 , 0 )
182+ termination_value = 1 if self .pos_termination else - 1
183+ return jnp .where (jnp .linalg .norm (state .qpos [..., :3 ], axis = - 1 ) > self .max_dist , termination_value , 0 )
182184
183185
184186@attrs .define (frozen = True , kw_only = True )
185187class EpisodeLengthTermination (Termination ):
186188 """Terminates the episode if the robot has been alive for too long.
187189
188- This is treated as a positive termination.
190+ This defaults to a positive termination.
189191 """
190192
191193 max_length_sec : float = attrs .field (validator = attrs .validators .gt (0.0 ))
192194 disable_at_curriculum_level : int = attrs .field (default = None )
195+ pos_termination : bool = attrs .field (default = True )
193196
194197 def __call__ (self , state : PhysicsData , curriculum_level : Array ) -> Array :
195- long_episodes = jnp .where (state .time > self .max_length_sec , 1 , 0 )
198+ termination_value = 1 if self .pos_termination else - 1
199+ long_episodes = jnp .where (state .time > self .max_length_sec , termination_value , 0 )
196200 if self .disable_at_curriculum_level is not None :
197201 return jnp .where (curriculum_level < self .disable_at_curriculum_level , 0 , long_episodes )
198202
0 commit comments