Skip to content

Commit 88c8c2d

Browse files
authored
add pos / neg toggle for dist termination (#450)
1 parent a459739 commit 88c8c2d

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

ksim/terminations.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,27 +172,31 @@ def __call__(self, state: PhysicsData, curriculum_level: Array) -> Array:
172172
class FarFromOriginTermination(Termination):
173173
"""Terminates the episode if the robot is too far from the origin.
174174
175-
This is treated as a positive termination.
175+
Defaults to a positive termination.
176176
"""
177177

178178
max_dist: float = attrs.field(validator=attrs.validators.gt(0.0))
179+
pos_termination: bool = attrs.field(default=True)
179180

180181
def __call__(self, state: PhysicsData, curriculum_level: Array) -> Array:
181-
return jnp.where(jnp.linalg.norm(state.qpos[..., :3], axis=-1) > self.max_dist, 1, 0)
182+
termination_value = 1 if self.pos_termination else -1
183+
return jnp.where(jnp.linalg.norm(state.qpos[..., :3], axis=-1) > self.max_dist, termination_value, 0)
182184

183185

184186
@attrs.define(frozen=True, kw_only=True)
185187
class EpisodeLengthTermination(Termination):
186188
"""Terminates the episode if the robot has been alive for too long.
187189
188-
This is treated as a positive termination.
190+
This defaults to a positive termination.
189191
"""
190192

191193
max_length_sec: float = attrs.field(validator=attrs.validators.gt(0.0))
192194
disable_at_curriculum_level: int = attrs.field(default=None)
195+
pos_termination: bool = attrs.field(default=True)
193196

194197
def __call__(self, state: PhysicsData, curriculum_level: Array) -> Array:
195-
long_episodes = jnp.where(state.time > self.max_length_sec, 1, 0)
198+
termination_value = 1 if self.pos_termination else -1
199+
long_episodes = jnp.where(state.time > self.max_length_sec, termination_value, 0)
196200
if self.disable_at_curriculum_level is not None:
197201
return jnp.where(curriculum_level < self.disable_at_curriculum_level, 0, long_episodes)
198202

0 commit comments

Comments
 (0)