|
12 | 12 | from typing import TYPE_CHECKING |
13 | 13 |
|
14 | 14 | from isaaclab.assets.articulation import Articulation |
| 15 | +import isaaclab.utils.math as math_utils |
15 | 16 |
|
16 | 17 | from isaaclab_arena_g1.g1_env.mdp.actions.g1_decoupled_wbc_joint_action import G1DecoupledWBCJointAction |
17 | 18 | from isaaclab_arena_g1.g1_whole_body_controller.wbc_policy.g1_wbc_upperbody_ik.g1_wbc_upperbody_controller import ( |
@@ -360,3 +361,35 @@ def reset(self, env_ids: Sequence[int] | None = None) -> None: |
360 | 361 | env_ids: A list of environment IDs to reset. If None, all environments are reset. |
361 | 362 | """ |
362 | 363 | self._raw_actions[env_ids] = torch.zeros(self.action_dim, device=self.device) |
| 364 | + |
| 365 | + def preprocess_actions(self, actions: torch.Tensor) -> torch.Tensor: |
| 366 | + """Transform wrist positions and orientations from world frame to robot base frame. |
| 367 | +
|
| 368 | + Args: |
| 369 | + actions: The input actions tensor, shape (action_dim,) or (1, action_dim). |
| 370 | +
|
| 371 | + Returns: |
| 372 | + The processed actions tensor (same shape as input). |
| 373 | + """ |
| 374 | + actions = actions.clone() |
| 375 | + |
| 376 | + robot_base_pos = self._asset.data.root_link_pos_w[0, :3] |
| 377 | + robot_base_quat = self._asset.data.root_link_quat_w[0] |
| 378 | + |
| 379 | + wrist_pos_world = torch.stack([actions[0, 2:5], actions[0, 9:12]], dim=0) |
| 380 | + wrist_pos_translated = wrist_pos_world - robot_base_pos |
| 381 | + robot_base_quat_batch = robot_base_quat.unsqueeze(0).expand(2, -1) |
| 382 | + wrist_pos_base = math_utils.quat_apply_inverse( |
| 383 | + robot_base_quat_batch, wrist_pos_translated |
| 384 | + ) |
| 385 | + |
| 386 | + wrist_quat_world = torch.stack([actions[0, 5:9], actions[0, 12:16]], dim=0) |
| 387 | + robot_base_quat_inv = math_utils.quat_inv(robot_base_quat.unsqueeze(0)).expand(2, -1) |
| 388 | + wrist_quat_base = math_utils.quat_mul(robot_base_quat_inv, wrist_quat_world) |
| 389 | + |
| 390 | + actions[0, 2:5] = wrist_pos_base[0] |
| 391 | + actions[0, 5:9] = wrist_quat_base[0] |
| 392 | + actions[0, 9:12] = wrist_pos_base[1] |
| 393 | + actions[0, 12:16] = wrist_quat_base[1] |
| 394 | + |
| 395 | + return actions |
0 commit comments