143 changes: 86 additions & 57 deletions source/isaaclab/isaaclab/envs/manager_based_env.py
@@ -336,31 +336,7 @@ def reset(
Returns:
A tuple containing the observations and extras.
"""
if env_ids is None:
env_ids = torch.arange(self.num_envs, dtype=torch.int64, device=self.device)

# trigger recorder terms for pre-reset calls
self.recorder_manager.record_pre_reset(env_ids)

# set the seed
if seed is not None:
self.seed(seed)

# reset state of scene
self._reset_idx(env_ids)

# update articulation kinematics
self.scene.write_data_to_sim()
self.sim.forward()
# if sensors are added to the scene, make sure we render to reflect changes in reset
if self.sim.has_rtx_sensors() and self.cfg.rerender_on_reset:
self.sim.render()

# trigger recorder terms for post-reset calls
self.recorder_manager.record_post_reset(env_ids)

# compute observations
self.obs_buf = self.observation_manager.compute(update_history=True)
self.obs_buf = self._reset(env_ids, None, seed)

if self.cfg.wait_for_textures and self.sim.has_rtx_sensors():
while SimulationManager.assets_loading():
@@ -375,7 +351,7 @@ def reset_to(
env_ids: Sequence[int] | None,
seed: int | None = None,
is_relative: bool = False,
):
) -> tuple[VecEnvObs, dict]:
"""Resets specified environments to provided states.

This function resets the environments to the provided states. The state is a dictionary
@@ -390,39 +366,16 @@ def reset_to(
:meth:`InteractiveScene.get_state` for the format.
env_ids: The environment ids to reset. Defaults to None, in which case all environments are reset.
seed: The seed to use for randomization. Defaults to None, in which case the seed is not set.
is_relative: If set to True, the state is considered relative to the environment origins.
Defaults to False.
"""
# reset all envs in the scene if env_ids is None
if env_ids is None:
env_ids = torch.arange(self.num_envs, dtype=torch.int64, device=self.device)

# trigger recorder terms for pre-reset calls
self.recorder_manager.record_pre_reset(env_ids)

# set the seed
if seed is not None:
self.seed(seed)

self._reset_idx(env_ids)
is_relative: If True, the state is considered relative to the environment origins. Defaults to False.

# set the state
self.scene.reset_to(state, env_ids, is_relative=is_relative)

# update articulation kinematics
self.sim.forward()

# if sensors are added to the scene, make sure we render to reflect changes in reset
if self.sim.has_rtx_sensors() and self.cfg.rerender_on_reset:
self.sim.render()

# trigger recorder terms for post-reset calls
self.recorder_manager.record_post_reset(env_ids)
Returns:
A tuple containing the observations and extras.
"""
if state is None:
raise ValueError("state cannot be None!")

# compute observations
self.obs_buf = self.observation_manager.compute(update_history=True)
self.obs_buf = self._reset(env_ids, state, seed, is_relative)

# return observations
return self.obs_buf, self.extras
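As a usage sketch, restoring environments from a previously captured state could look like the following. The `get_state` call and its `is_relative` argument are assumptions based on the docstring's reference to :meth:`InteractiveScene.get_state`, not part of this diff:

```python
# Capture the scene state relative to the environment origins
# (hedged: the exact get_state signature is assumed from the docstring reference).
state = env.scene.get_state(is_relative=True)

# ... interact with the environment ...

# Restore environments 0 and 1 to the captured state and read back observations.
obs, extras = env.reset_to(state, env_ids=[0, 1], is_relative=True)
```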

def step(self, action: torch.Tensor) -> tuple[VecEnvObs, dict]:
@@ -471,7 +424,7 @@ def step(self, action: torch.Tensor) -> tuple[VecEnvObs, dict]:
self.event_manager.apply(mode="interval", dt=self.step_dt)

# -- compute observations
self.obs_buf = self.observation_manager.compute(update_history=True)
self.obs_buf = self._get_observations(update_history=True)
self.recorder_manager.record_post_step()

# return observations and extras
@@ -529,6 +482,82 @@ def close(self):
Helper functions.
"""

def _get_observations(self, update_history: bool = False) -> VecEnvObs:
"""
Computes and returns the current observation dictionary for the environment.

Args:
update_history: Whether the computed observations should also be appended to the
observation history. Defaults to False, in which case computing the observations
does not modify the history. This flag has no effect for groups with ``history_length == 0``.

Returns:
A dictionary containing the full set of observations.
"""
return self.observation_manager.compute(update_history)
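Routing every observation query through this hook means a subclass can customize observation post-processing in one place instead of patching each call site. A hypothetical sketch (the subclass and the noise injection are illustrative only; `torch` and `VecEnvObs` are in scope within this module):

```python
class NoisyObsEnv(ManagerBasedEnv):
    """Hypothetical subclass that post-processes observations centrally."""

    def _get_observations(self, update_history: bool = False) -> VecEnvObs:
        obs = super()._get_observations(update_history=update_history)
        # Illustrative tweak: perturb the "policy" group if it is a flat tensor
        # (whether it is depends on the group's concatenate_terms setting).
        policy_obs = obs.get("policy")
        if isinstance(policy_obs, torch.Tensor):
            obs["policy"] = policy_obs + 0.01 * torch.randn_like(policy_obs)
        return obs
```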

def _reset(
self,
env_ids: Sequence[int] | None,
state: dict[str, dict[str, dict[str, torch.Tensor]]] | None = None,
seed: int | None = None,
is_relative: bool = False,
) -> VecEnvObs:
"""Reset the specified environments to a given or randomized state.

If a ``state`` is provided, the environments are restored accordingly.
Otherwise, they are reset using the environment randomization logic.

This function calls :meth:`_reset_idx` to reset the specified environments. However,
certain operations that happen during initialization, such as procedural terrain
generation, are not repeated.

Args:
env_ids: The environment ids to reset. Defaults to None, in which case all environments are reset.
state: A dictionary containing the state of the scene entities. Defaults to None,
in which case the environments are reset via the randomization logic.
Please refer to :meth:`InteractiveScene.get_state` for the format.
seed: The seed to use for randomization. Defaults to None, in which case the seed is not set.
is_relative: If True, the state is considered relative to the environment origins. Defaults to False.

Returns:
A dictionary containing the full set of observations.
"""
# reset all envs in the scene if env_ids is None
if env_ids is None:
env_ids = torch.arange(self.num_envs, dtype=torch.int64, device=self.device)

# trigger recorder terms for pre-reset calls
self.recorder_manager.record_pre_reset(env_ids)

# set the seed
if seed is not None:
self.seed(seed)

# reset state of scene
self._reset_idx(env_ids)

# set the state
if state is None:
self.scene.write_data_to_sim()
else:
self.scene.reset_to(state, env_ids, is_relative=is_relative)

# update articulation kinematics
self.sim.forward()

# if sensors are added to the scene, make sure we render to reflect changes in reset
if self.sim.has_rtx_sensors() and self.cfg.rerender_on_reset:
self.sim.render()

# trigger recorder terms for post-reset calls
self.recorder_manager.record_post_reset(env_ids)

# compute observations
self.obs_buf = self._get_observations(update_history=True)

# return observations
return self.obs_buf
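From the annotation ``dict[str, dict[str, dict[str, torch.Tensor]]]``, the state nests asset type, asset name, and attribute. A hypothetical example of the two branches (key names are illustrative, the authoritative format is documented in :meth:`InteractiveScene.get_state`, and the private helper is called directly here only to trace the control flow):

```python
import torch

# Hypothetical state layout: asset-type -> asset-name -> attribute -> tensor.
state = {
    "articulation": {
        "robot": {
            "joint_position": torch.zeros(env.num_envs, 7),  # 7 joints, illustrative
            "joint_velocity": torch.zeros(env.num_envs, 7),
        },
    },
}

obs = env._reset(env_ids=None, seed=42)                          # randomized-reset branch
obs = env._reset(env_ids=[0, 1], state=state, is_relative=True)  # state-restore branch
```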

def _reset_idx(self, env_ids: Sequence[int]):
"""Reset environments based on specified indices.

4 changes: 2 additions & 2 deletions source/isaaclab/isaaclab/envs/manager_based_rl_env.py
@@ -210,7 +210,7 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn:

if len(self.recorder_manager.active_terms) > 0:
# update observations for recording if needed
self.obs_buf = self.observation_manager.compute()
self.obs_buf = self._get_observations()
self.recorder_manager.record_post_step()

# -- reset envs that terminated/timed-out and log the episode information
@@ -235,7 +235,7 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn:
self.event_manager.apply(mode="interval", dt=self.step_dt)
# -- compute observations
# note: done after reset to get the correct observations for reset envs
self.obs_buf = self.observation_manager.compute(update_history=True)
self.obs_buf = self._get_observations(update_history=True)

# return observations, rewards, resets and extras
return self.obs_buf, self.reward_buf, self.reset_terminated, self.reset_time_outs, self.extras
14 changes: 14 additions & 0 deletions source/isaaclab/isaaclab/envs/mdp/actions/joint_actions.py
@@ -177,6 +177,10 @@ def process_actions(self, actions: torch.Tensor):
def reset(self, env_ids: Sequence[int] | None = None) -> None:
self._raw_actions[env_ids] = 0.0

def normalize_processed_actions(self, processed_actions: torch.Tensor) -> torch.Tensor:
offset_free_actions = processed_actions - self._offset
return torch.where(self._scale > 1e-8, offset_free_actions / self._scale, offset_free_actions)
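This inverse undoes an affine forward transform (assumed here to be ``processed = raw * scale + offset``, consistent with the subtraction and division above), while guarding against division by a near-zero scale. A small self-contained round-trip check with illustrative tensors:

```python
import torch

scale = torch.tensor([2.0, 0.0, 0.5])    # per-joint scales; middle joint is degenerate
offset = torch.tensor([0.1, -0.3, 0.0])
raw = torch.tensor([0.5, 0.2, -1.0])     # normalized policy output

processed = raw * scale + offset         # assumed forward transform
recovered = torch.where(scale > 1e-8, (processed - offset) / scale, processed - offset)
print(recovered)  # tensor([ 0.5000,  0.0000, -1.0000])
# Where the scale is effectively zero, the raw value is unrecoverable,
# so only the offset is removed (yielding 0.0 for the middle joint).
```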


class JointPositionAction(JointAction):
"""Joint action term that applies the processed actions to the articulation's joints as position commands."""
@@ -228,6 +232,16 @@ def apply_actions(self):
# set position targets
self._asset.set_joint_position_target(current_actions, joint_ids=self._joint_ids)

def normalize_processed_actions(self, processed_actions: torch.Tensor) -> torch.Tensor:
"""Normalization of processed actions is not supported.

This method cannot be applied since the transformation is performed during the action application
stage (:meth:`apply_actions`) rather than during processing (:meth:`process_actions`).
"""
raise NotImplementedError(
f"Normalization of the processed actions is not supported for {self.__class__.__name__}."
)
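Because terms like this one raise instead of implementing the hook, a caller mixing term types can guard the call; a minimal sketch (using ``term.processed_actions`` as the input is an assumption about the term's interface):

```python
try:
    normalized = term.normalize_processed_actions(term.processed_actions)
except NotImplementedError:
    normalized = None  # term applies its transform in apply_actions; no inverse available
```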


class JointVelocityAction(JointAction):
"""Joint action term that applies the processed actions to the articulation's joints as velocity commands."""
15 changes: 15 additions & 0 deletions source/isaaclab/isaaclab/managers/action_manager.py
@@ -164,6 +164,21 @@ def apply_actions(self):
"""
raise NotImplementedError

def normalize_processed_actions(self, processed_actions: torch.Tensor) -> torch.Tensor:
"""Maps the processed actions to the normalized action space.

This function takes processed (e.g., scaled or shifted) actions and applies the inverse
transformation to recover the normalized action range output by the policy.

Args:
processed_actions: The processed actions, typically a scaled or shifted version of the policy output.

Returns:
A tensor of actions mapped back to the normalized action space.
"""
raise NotImplementedError(
f"Normalization of the processed actions is not implemented for {self.__class__.__name__}."
)
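Implementing this hook in a custom term amounts to inverting whatever `process_actions` does. A hypothetical pure-offset term (other abstract members of `ActionTerm` are omitted for brevity):

```python
class ShiftedActionTerm(ActionTerm):
    """Hypothetical term whose processing is a pure offset."""

    def process_actions(self, actions: torch.Tensor):
        self._processed_actions = actions + self._offset

    def normalize_processed_actions(self, processed_actions: torch.Tensor) -> torch.Tensor:
        # Exact inverse of process_actions: drop the offset to recover policy-space actions.
        return processed_actions - self._offset
```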

def _set_debug_vis_impl(self, debug_vis: bool):
"""Set debug visualization into visualization objects.
This function is responsible for creating the visualization objects if they don't exist
17 changes: 10 additions & 7 deletions source/isaaclab/isaaclab/managers/observation_manager.py
@@ -62,6 +62,15 @@ class ObservationManager(ManagerBase):
The observations are clipped and scaled as per the configuration settings.
"""

# Configuration fields to skip when parsing observation term configs
_EXCLUDED_CFG_KEYS: tuple[str, ...] = (
"enable_corruption",
"concatenate_terms",
"history_length",
"flatten_history_dim",
"concatenate_dim",
)
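Promoting the skip-list to a class attribute makes it extensible: a subclass that introduces its own group-level setting can append to it rather than re-implementing `_prepare_terms`. A hypothetical example:

```python
class MyObservationManager(ObservationManager):
    # "my_group_flag" is a hypothetical extra group-level setting to skip during parsing.
    _EXCLUDED_CFG_KEYS = ObservationManager._EXCLUDED_CFG_KEYS + ("my_group_flag",)
```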

def __init__(self, cfg: object, env: ManagerBasedEnv):
"""Initialize observation manager.

@@ -515,13 +524,7 @@ def _prepare_terms(self):
# iterate over all the terms in each group
for term_name, term_cfg in group_cfg_items:
# skip non-obs settings
if term_name in [
"enable_corruption",
"concatenate_terms",
"history_length",
"flatten_history_dim",
"concatenate_dim",
]:
if term_name in self._EXCLUDED_CFG_KEYS:
continue
# check for non config
if term_cfg is None:
5 changes: 1 addition & 4 deletions source/isaaclab_rl/isaaclab_rl/rsl_rl/vecenv_wrapper.py
@@ -143,10 +143,7 @@ def reset(self) -> tuple[TensorDict, dict]:  # noqa: D102

def get_observations(self) -> TensorDict:
"""Returns the current observations of the environment."""
if hasattr(self.unwrapped, "observation_manager"):
obs_dict = self.unwrapped.observation_manager.compute()
else:
obs_dict = self.unwrapped._get_observations()
obs_dict = self.unwrapped._get_observations()
return TensorDict(obs_dict, batch_size=[self.num_envs])
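With the fallback removed, the wrapper relies on every environment exposing `_get_observations`. Typical usage from the training side might look like this (the "policy" group name is an assumption about the task's observation config):

```python
obs_td = env.get_observations()   # TensorDict keyed by observation group
policy_obs = obs_td["policy"]     # (num_envs, obs_dim) if the group is concatenated
```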

def step(self, actions: torch.Tensor) -> tuple[TensorDict, torch.Tensor, torch.Tensor, dict]:
Expand Down
Loading