Skip to content

Commit ba26d7a

Browse files
committed
G1 preprocess_action applies to batch
1 parent 6017c38 commit ba26d7a

File tree

2 files changed

+207
-14
lines changed

2 files changed

+207
-14
lines changed
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md).
2+
# All rights reserved.
3+
#
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
"""Unit tests for G1 WBC Pink action preprocess_actions (world → robot base frame)."""
7+
8+
import torch
9+
10+
from isaaclab_arena.tests.utils.subprocess import run_simulation_app_function
11+
12+
HEADLESS = True
13+
14+
15+
def _get_g1_pink_env_and_term(simulation_app):
    """Build a G1 WBC Pink env at the origin with identity orientation.

    Returns:
        A tuple ``(env, term)`` where ``term`` is the registered ``g1_action`` term.
    """
    # Imports are deferred: they require the simulation app to be running.
    from isaaclab_arena.assets.asset_registry import AssetRegistry
    from isaaclab_arena.cli.isaaclab_arena_cli import get_isaaclab_arena_cli_parser
    from isaaclab_arena.embodiments.g1.g1 import G1WBCPinkEmbodiment
    from isaaclab_arena.environments.arena_env_builder import ArenaEnvBuilder
    from isaaclab_arena.environments.isaaclab_arena_environment import IsaacLabArenaEnvironment
    from isaaclab_arena.scene.scene import Scene
    from isaaclab_arena.utils.pose import Pose

    registry = AssetRegistry()
    kitchen = registry.get_asset_by_name("kitchen")()
    # Base at origin with identity rotation, so world/base frames are easy to reason about.
    g1 = G1WBCPinkEmbodiment(enable_cameras=False)
    g1.set_initial_pose(Pose(position_xyz=(0.0, 0.0, 0.0), rotation_wxyz=(1.0, 0.0, 0.0, 0.0)))
    arena_environment = IsaacLabArenaEnvironment(
        name="g1_pink_preprocess_test",
        embodiment=g1,
        scene=Scene(assets=[kitchen]),
    )
    cli_args = get_isaaclab_arena_cli_parser().parse_args([])
    env = ArenaEnvBuilder(arena_environment, cli_args).make_registered()
    env.reset()
    return env, env.unwrapped.action_manager.get_term("g1_action")
41+
42+
43+
def _test_preprocess_actions_shape(simulation_app) -> bool:
    """preprocess_actions preserves shape (num_envs, action_dim)."""
    env, term = _get_g1_pink_env_and_term(simulation_app)
    try:
        num_envs, action_dim = env.num_envs, term.action_dim
        # An all-zero batch is enough: only the output shape is under test here.
        out = term.preprocess_actions(torch.zeros(num_envs, action_dim, device=env.unwrapped.device))
        assert out.shape == (num_envs, action_dim), f"Expected shape ({num_envs}, {action_dim}), got {out.shape}"
    finally:
        # Always tear down the env, even when the assertion fails.
        env.close()
    return True
55+
56+
57+
def _test_preprocess_actions_identity_base(simulation_app) -> bool:
    """When robot base has identity quat, wrist in base frame = world pos minus base pos."""
    env, term = _get_g1_pink_env_and_term(simulation_app)
    try:
        device = env.unwrapped.device
        action_dim = term.action_dim
        robot_base_pos = term._asset.data.root_link_pos_w[0, :3]
        robot_base_quat = term._asset.data.root_link_quat_w[0]
        # The offset/quat expectations below only hold when the base orientation is
        # identity; assert that precondition explicitly so a drifted base fails with
        # a clear message instead of a confusing position mismatch.
        # (Fix: robot_base_quat was previously assigned but never used.)
        torch.testing.assert_close(
            robot_base_quat, torch.tensor([1.0, 0.0, 0.0, 0.0], device=device), atol=1e-4, rtol=0
        )

        # World-frame wrist positions: base + offset (so base-frame offset is known)
        left_offset = torch.tensor([1.0, 2.0, 3.0], device=device)
        right_offset = torch.tensor([4.0, 5.0, 6.0], device=device)
        left_pos_world = robot_base_pos + left_offset
        right_pos_world = robot_base_pos + right_offset

        # Action layout: left pos 2:5, left quat 5:9, right pos 9:12, right quat 12:16.
        actions = torch.zeros(1, action_dim, device=device)
        actions[0, 2:5] = left_pos_world
        actions[0, 5:9] = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device)
        actions[0, 9:12] = right_pos_world
        actions[0, 12:16] = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device)

        out = term.preprocess_actions(actions)

        # Base frame position = world - base (in world), then rotated by base_inv => offset when base quat is identity
        torch.testing.assert_close(out[0, 2:5], left_offset, atol=1e-4, rtol=0)
        torch.testing.assert_close(out[0, 9:12], right_offset, atol=1e-4, rtol=0)
        # Identity base rotation must leave the wrist quaternions unchanged.
        torch.testing.assert_close(out[0, 5:9], actions[0, 5:9], atol=1e-5, rtol=0)
        torch.testing.assert_close(out[0, 12:16], actions[0, 12:16], atol=1e-5, rtol=0)
    finally:
        env.close()
    return True
88+
89+
90+
def _test_preprocess_actions_roundtrip(simulation_app) -> bool:
    """Preprocess world→base; then base→world recovers original (using current robot pose)."""
    import isaaclab.utils.math as math_utils

    env, term = _get_g1_pink_env_and_term(simulation_app)
    try:
        device = env.unwrapped.device
        asset = term._asset

        base_pos = asset.data.root_link_pos_w[:, :3]
        base_quat = asset.data.root_link_quat_w
        n = base_pos.shape[0]

        # Arbitrary world-frame wrist poses, tiled across all envs.
        left_pos_w = torch.tensor([[1.0, 0.0, 0.5]], device=device).expand(n, 3)
        left_quat_w = torch.tensor([[1.0, 0.0, 0.0, 0.0]], device=device).expand(n, 4)
        right_pos_w = torch.tensor([[0.0, 1.0, 0.5]], device=device).expand(n, 3)
        right_quat_w = torch.tensor([[1.0, 0.0, 0.0, 0.0]], device=device).expand(n, 4)

        actions = torch.zeros(n, term.action_dim, device=device)
        actions[:, 2:5] = left_pos_w
        actions[:, 5:9] = left_quat_w
        actions[:, 9:12] = right_pos_w
        actions[:, 12:16] = right_quat_w

        out = term.preprocess_actions(actions)

        # Base → world: pos_w = base_pos + quat_apply(base_quat, pos_b),
        # quat_w = quat_mul(base_quat, quat_b); recovered values must match the inputs.
        for pos_sl, quat_sl, pos_w, quat_w in (
            (slice(2, 5), slice(5, 9), left_pos_w, left_quat_w),
            (slice(9, 12), slice(12, 16), right_pos_w, right_quat_w),
        ):
            pos_recovered = base_pos + math_utils.quat_apply(base_quat, out[:, pos_sl])
            quat_recovered = math_utils.quat_mul(base_quat, out[:, quat_sl])
            torch.testing.assert_close(pos_recovered, pos_w, atol=1e-5, rtol=0)
            torch.testing.assert_close(quat_recovered, quat_w, atol=1e-5, rtol=0)
    finally:
        env.close()
    return True
135+
136+
137+
def _test_preprocess_actions_does_not_mutate_other_slots(simulation_app) -> bool:
    """Indices outside wrist pos/quat (e.g. hand state, navigate_cmd) are unchanged."""
    env, term = _get_g1_pink_env_and_term(simulation_app)
    try:
        device = env.unwrapped.device
        # Populate only non-wrist slots with recognizable values.
        actions = torch.zeros(1, term.action_dim, device=device)
        actions[0, 0] = 0.5
        actions[0, 1] = 0.7
        actions[0, 16:19] = torch.tensor([0.1, 0.2, 0.3], device=device)
        actions[0, 19] = 0.75
        actions[0, 20:23] = torch.tensor([0.0, 0.0, 0.1], device=device)

        out = term.preprocess_actions(actions)

        # Scalar slots keep their exact values.
        for idx, expected in ((0, 0.5), (1, 0.7)):
            torch.testing.assert_close(out[0, idx], torch.tensor(expected, device=device), atol=1e-6, rtol=0)
        # Vector slots round-trip untouched (compared against the unmodified input).
        for sl in (slice(16, 19), slice(19, 20), slice(20, 23)):
            torch.testing.assert_close(out[0, sl], actions[0, sl], atol=1e-6, rtol=0)
    finally:
        env.close()
    return True
160+
161+
162+
def test_g1_wbc_pink_preprocess_actions_shape():
    """Run the shape-preservation check inside an isolated simulation-app subprocess."""
    assert run_simulation_app_function(
        _test_preprocess_actions_shape, headless=HEADLESS
    ), "preprocess_actions shape test failed"
168+
169+
170+
def test_g1_wbc_pink_preprocess_actions_identity_base():
    """Run the identity-base check inside an isolated simulation-app subprocess."""
    assert run_simulation_app_function(
        _test_preprocess_actions_identity_base, headless=HEADLESS
    ), "preprocess_actions identity base test failed"
176+
177+
178+
def test_g1_wbc_pink_preprocess_actions_roundtrip():
    """Run the world→base→world round-trip check inside an isolated simulation-app subprocess."""
    assert run_simulation_app_function(
        _test_preprocess_actions_roundtrip, headless=HEADLESS
    ), "preprocess_actions roundtrip test failed"
184+
185+
186+
def test_g1_wbc_pink_preprocess_actions_does_not_mutate_other_slots():
    """Run the non-wrist-slot invariance check inside an isolated simulation-app subprocess."""
    assert run_simulation_app_function(
        _test_preprocess_actions_does_not_mutate_other_slots, headless=HEADLESS
    ), "preprocess_actions other slots test failed"

isaaclab_arena_g1/g1_env/mdp/actions/g1_decoupled_wbc_pink_action.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -372,28 +372,30 @@ def preprocess_actions(self, actions: torch.Tensor) -> torch.Tensor:
372372
"""Transform wrist positions and orientations from world frame to robot base frame.
373373
374374
Args:
375-
actions: The input actions tensor, shape (action_dim,) or (1, action_dim).
375+
actions: The input actions tensor, shape (num_envs, action_dim).
376376
377377
Returns:
378378
The processed actions tensor (same shape as input).
379379
"""
380380
actions = actions.clone()
381381

382-
robot_base_pos = self._asset.data.root_link_pos_w[0, :3]
383-
robot_base_quat = self._asset.data.root_link_quat_w[0]
382+
robot_base_pos = self._asset.data.root_link_pos_w[:, :3]
383+
robot_base_quat = self._asset.data.root_link_quat_w
384384

385-
wrist_pos_world = torch.stack([actions[0, 2:5], actions[0, 9:12]], dim=0)
386-
wrist_pos_translated = wrist_pos_world - robot_base_pos
387-
robot_base_quat_batch = robot_base_quat.unsqueeze(0).expand(2, -1)
388-
wrist_pos_base = math_utils.quat_apply_inverse(robot_base_quat_batch, wrist_pos_translated)
385+
left_wrist_pos_world = actions[:, 2:5]
386+
right_wrist_pos_world = actions[:, 9:12]
387+
left_wrist_pos_base = math_utils.quat_apply_inverse(robot_base_quat, left_wrist_pos_world - robot_base_pos)
388+
right_wrist_pos_base = math_utils.quat_apply_inverse(robot_base_quat, right_wrist_pos_world - robot_base_pos)
389389

390-
wrist_quat_world = torch.stack([actions[0, 5:9], actions[0, 12:16]], dim=0)
391-
robot_base_quat_inv = math_utils.quat_inv(robot_base_quat.unsqueeze(0)).expand(2, -1)
392-
wrist_quat_base = math_utils.quat_mul(robot_base_quat_inv, wrist_quat_world)
390+
left_wrist_quat_world = actions[:, 5:9]
391+
right_wrist_quat_world = actions[:, 12:16]
392+
robot_base_quat_inv = math_utils.quat_inv(robot_base_quat)
393+
left_wrist_quat_base = math_utils.quat_mul(robot_base_quat_inv, left_wrist_quat_world)
394+
right_wrist_quat_base = math_utils.quat_mul(robot_base_quat_inv, right_wrist_quat_world)
393395

394-
actions[0, 2:5] = wrist_pos_base[0]
395-
actions[0, 5:9] = wrist_quat_base[0]
396-
actions[0, 9:12] = wrist_pos_base[1]
397-
actions[0, 12:16] = wrist_quat_base[1]
396+
actions[:, 2:5] = left_wrist_pos_base
397+
actions[:, 5:9] = left_wrist_quat_base
398+
actions[:, 9:12] = right_wrist_pos_base
399+
actions[:, 12:16] = right_wrist_quat_base
398400

399401
return actions

0 commit comments

Comments
 (0)