Skip to content

Commit fb4f7db

Browse files
committed
added head pose transformation to action
1 parent 19341f9 commit fb4f7db

File tree

1 file changed

+176
-25
lines changed

1 file changed

+176
-25
lines changed

src/opentau/scripts/RecordHuman_to_lerobot.py

Lines changed: 176 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
[pos_x, pos_y, pos_z, quat_x, quat_y, quat_z, quat_w]) for left hand (182)
1919
concatenated with right hand (182). Zeros when a hand is not tracked.
2020
21-
Action: the next frame's state (last frame repeats its own state).
21+
Action vector (371-D): next frame's state (364) concatenated with the delta head
22+
pose (7 = delta_pos(3) + delta_rot(4)). Delta position is pos_{t+1} - pos_{t}.
23+
Delta rotation is the relative quaternion q_{t+1} * q_t^{-1}. The last frame
24+
uses zero delta position and identity quaternion [0,0,0,1].
2225
2326
Video frames are stored as observation.images.camera.
2427
@@ -50,6 +53,8 @@
5053
FLOATS_PER_JOINT = 7 # pos(3) + rot(4)
5154
STATE_DIM_PER_HAND = JOINTS_PER_HAND * FLOATS_PER_JOINT # 182
5255
STATE_DIM = STATE_DIM_PER_HAND * 2 # 364: left(182) + right(182)
56+
HEAD_POSE_DIM = 7 # pos(3) + rot(4)
57+
ACTION_DIM = STATE_DIM + HEAD_POSE_DIM # 371: next-state(364) + delta head pose(7)
5358

5459
FINGERTIP_INDICES = {5, 10, 15, 20, 25}
5560

@@ -95,6 +100,18 @@
95100

96101

97102
def quat_to_rotation_matrix(q):
103+
"""Convert an (x, y, z, w) quaternion to a 3x3 rotation matrix.
104+
105+
The returned matrix R rotates column vectors from the local frame into the
106+
world frame: ``p_world = R @ p_local``.
107+
108+
Args:
109+
q: Quaternion as a 4-element array-like in (x, y, z, w) order.
110+
Normalized internally.
111+
112+
Returns:
113+
A (3, 3) numpy rotation matrix.
114+
"""
98115
x, y, z, w = q / np.linalg.norm(q)
99116
return np.array(
100117
[
@@ -106,7 +123,22 @@ def quat_to_rotation_matrix(q):
106123

107124

108125
def rotation_matrix_to_quat(R):
109-
"""Convert a 3×3 rotation matrix to (x, y, z, w) quaternion."""
126+
"""Convert a 3x3 rotation matrix to an (x, y, z, w) quaternion.
127+
128+
Uses Shepperd's method, which selects the numerically stable branch
129+
based on the matrix diagonal.
130+
131+
Note:
132+
The sign of the returned quaternion is arbitrary (q and -q represent
133+
the same rotation). Avoid round-tripping through this function when
134+
sign consistency matters; use ``quat_inverse`` directly instead.
135+
136+
Args:
137+
R: A (3, 3) rotation matrix.
138+
139+
Returns:
140+
A length-4 numpy array ``[x, y, z, w]``.
141+
"""
110142
trace = R[0, 0] + R[1, 1] + R[2, 2]
111143
if trace > 0:
112144
s = 0.5 / np.sqrt(trace + 1.0)
@@ -135,20 +167,32 @@ def rotation_matrix_to_quat(R):
135167
return np.array([x, y, z, w])
136168

137169

138-
def rotate_quaternions_by_matrix(R, quats):
139-
"""Rotate an array of (N, 4) quaternions (x,y,z,w) by rotation matrix R.
170+
def quat_inverse(q):
171+
"""Return the inverse of a unit quaternion (its conjugate).
140172
141-
Computes q_out = R_as_quat * q_in for each row.
173+
Args:
174+
q: Unit quaternion as a 4-element array-like in (x, y, z, w) order.
175+
176+
Returns:
177+
A length-4 numpy array ``[-x, -y, -z, w]``.
142178
"""
143-
r_q = rotation_matrix_to_quat(R)
144-
out = np.empty_like(quats)
145-
for i in range(len(quats)):
146-
out[i] = quat_multiply(r_q, quats[i])
147-
return out
179+
x, y, z, w = q
180+
return np.array([-x, -y, -z, w])
148181

149182

150183
def quat_multiply(q1, q2):
151-
"""Hamilton product of two (x,y,z,w) quaternions."""
184+
"""Compute the Hamilton product of two quaternions.
185+
186+
The result represents rotation q1 applied *after* q2:
187+
``q_out = q1 * q2`` means "first rotate by q2, then by q1".
188+
189+
Args:
190+
q1: Left quaternion, (x, y, z, w).
191+
q2: Right quaternion, (x, y, z, w).
192+
193+
Returns:
194+
A length-4 numpy array ``[x, y, z, w]``.
195+
"""
152196
x1, y1, z1, w1 = q1
153197
x2, y2, z2, w2 = q2
154198
return np.array(
@@ -162,6 +206,23 @@ def quat_multiply(q1, q2):
162206

163207

164208
def world_to_camera(points_tracking, head_pos_tracking, head_rot_q, eye_offset=None):
209+
"""Transform 3-D points from tracking (world) space to camera (head-local) space.
210+
211+
Computes ``p_cam = R_world_to_cam @ (p_world - eye_pos)`` using row-vector
212+
convention internally (``(N,3) @ (3,3)``).
213+
214+
Args:
215+
points_tracking: (N, 3) array of points in tracking/world coordinates.
216+
head_pos_tracking: (3,) head position in tracking space (after origin
217+
subtraction).
218+
head_rot_q: (4,) head orientation quaternion (x, y, z, w) mapping local
219+
to world.
220+
eye_offset: Optional (3,) offset from head center to the eye in head-
221+
local coordinates. Applied after rotating by head orientation.
222+
223+
Returns:
224+
(N, 3) array of points in camera space.
225+
"""
165226
R_local_to_world = quat_to_rotation_matrix(head_rot_q)
166227
R_world_to_local = R_local_to_world.T
167228
eye_pos = head_pos_tracking
@@ -171,6 +232,22 @@ def world_to_camera(points_tracking, head_pos_tracking, head_rot_q, eye_offset=N
171232

172233

173234
def project_camera_to_pixel(points_cam, fx, fy, cx, cy):
235+
"""Project camera-space 3-D points to 2-D pixel coordinates (pinhole model).
236+
237+
Y is negated so that camera-up maps to image-down (standard image coords).
238+
Points behind the camera (Z <= 0.01) are marked invalid and set to NaN.
239+
240+
Args:
241+
points_cam: (N, 3) array in camera coordinates (Z forward, Y up).
242+
fx: Horizontal focal length in pixels.
243+
fy: Vertical focal length in pixels.
244+
cx: Principal point x (pixels).
245+
cy: Principal point y (pixels).
246+
247+
Returns:
248+
Tuple of ``(px, py, valid)`` where ``px`` and ``py`` are (N,) float
249+
arrays (NaN for invalid points), and ``valid`` is a boolean mask.
250+
"""
174251
valid = points_cam[:, 2] > 0.01
175252
px = np.full(len(points_cam), np.nan)
176253
py = np.full(len(points_cam), np.nan)
@@ -181,12 +258,27 @@ def project_camera_to_pixel(points_cam, fx, fy, cx, cy):
181258

182259

183260
def parse_hand_joints(flat_array):
184-
"""Return (26, 7) array of [pos_x, pos_y, pos_z, quat_x, quat_y, quat_z, quat_w]."""
261+
"""Parse a flat joint array into a structured (26, 7) array.
262+
263+
Args:
264+
flat_array: Flat sequence of 182 floats (26 joints x 7 values per
265+
joint: pos_x, pos_y, pos_z, quat_x, quat_y, quat_z, quat_w).
266+
267+
Returns:
268+
(26, 7) float64 numpy array.
269+
"""
185270
return np.array(flat_array, dtype=np.float64).reshape(JOINTS_PER_HAND, FLOATS_PER_JOINT)
186271

187272

188273
def parse_hand_positions(flat_array):
189-
"""Return (26, 3) positions only — used for overlay projection."""
274+
"""Parse a flat joint array and return only the 3-D positions.
275+
276+
Args:
277+
flat_array: Flat sequence of 182 floats (see ``parse_hand_joints``).
278+
279+
Returns:
280+
(26, 3) float64 numpy array of joint positions.
281+
"""
190282
return parse_hand_joints(flat_array)[:, :3]
191283

192284

@@ -196,6 +288,14 @@ def parse_hand_positions(flat_array):
196288

197289

198290
def draw_hand(frame, positions_2d, valid, color):
291+
"""Draw a hand skeleton (bones and joints) onto an image.
292+
293+
Args:
294+
frame: BGR image (H, W, 3), modified in place.
295+
positions_2d: (26, 2) array of pixel coordinates per joint.
296+
valid: (26,) boolean mask indicating which joints are visible.
297+
color: BGR tuple for the skeleton color.
298+
"""
199299
h, w = frame.shape[:2]
200300
for a, b in SKELETON_BONES:
201301
if not (valid[a] and valid[b]):
@@ -216,6 +316,16 @@ def draw_hand(frame, positions_2d, valid, color):
216316

217317

218318
def draw_head_gizmo(frame, head_rot_q):
319+
"""Draw a 3-axis orientation gizmo showing how world axes appear in camera view.
320+
321+
Renders X (red), Y (green), Z (blue) arrows in the upper-left corner of
322+
the frame. Each arrow shows the direction of the corresponding world axis
323+
as seen from the head's local frame.
324+
325+
Args:
326+
frame: BGR image (H, W, 3), modified in place.
327+
head_rot_q: (4,) head orientation quaternion (x, y, z, w), local-to-world.
328+
"""
219329
R = quat_to_rotation_matrix(head_rot_q)
220330
Rw2l = R.T
221331
ox, oy = 60, 60
@@ -242,6 +352,17 @@ def draw_head_gizmo(frame, head_rot_q):
242352

243353

244354
def draw_hud(frame, video_time, pose_time, frame_idx, total_frames, left_tracked, right_tracked):
355+
"""Draw a heads-up display with timing and hand-tracking status.
356+
357+
Args:
358+
frame: BGR image (H, W, 3), modified in place.
359+
video_time: Current video timestamp in seconds.
360+
pose_time: Matched pose-data timestamp in seconds.
361+
frame_idx: Current video frame index.
362+
total_frames: Total number of video frames.
363+
left_tracked: Whether the left hand is currently tracked.
364+
right_tracked: Whether the right hand is currently tracked.
365+
"""
245366
h, w = frame.shape[:2]
246367
info = f"t={video_time:.2f}s pose_t={pose_time:.2f}s frame {frame_idx}/{total_frames}"
247368
cv2.putText(frame, info, (10, h - 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (200, 200, 200), 1, cv2.LINE_AA)
@@ -262,17 +383,39 @@ def draw_hud(frame, video_time, pose_time, frame_idx, total_frames, left_tracked
262383

263384

264385
def compute_frame_state(pf, head_pos_tracking, head_rot, eye_offset):
265-
"""Return (state_vector[364], left_cam_pos[26,3]|None, right_cam_pos[26,3]|None).
266-
267-
State layout per hand (182 floats):
268-
joint_0_pos(3) joint_0_quat(4) joint_1_pos(3) joint_1_quat(4) ... joint_25_pos(3) joint_25_quat(4)
269-
Both positions and quaternions are in camera space.
386+
"""Build the 364-D camera-space state vector for one pose frame.
387+
388+
Each hand contributes 182 floats (26 joints x 7: pos(3) + quat(4)), all
389+
expressed in head/camera space. An untracked hand is left as zeros.
390+
391+
Joint positions are transformed via ``world_to_camera``. Joint quaternions
392+
are rotated by ``q_head^{-1}`` (directly, without a matrix round-trip) to
393+
convert from world orientation to camera-local orientation.
394+
395+
Note:
396+
``head_pos_tracking`` is expected to already have the XR tracking
397+
origin subtracted. The raw joint positions from ``pf`` are used
398+
as-is because the Pico recording format stores them in a coordinate
399+
system that does not require the same origin correction.
400+
401+
Args:
402+
pf: Single frame dict from the pose JSON (must contain ``left_joints``
403+
/ ``right_joints`` and ``left_tracked`` / ``right_tracked``).
404+
head_pos_tracking: (3,) head position in tracking space (origin-
405+
subtracted).
406+
head_rot: (4,) head orientation quaternion (x, y, z, w), local-to-world.
407+
eye_offset: Optional (3,) eye offset in head-local coords, or ``None``.
408+
409+
Returns:
410+
Tuple of ``(state, left_cam_pos, right_cam_pos)`` where ``state`` is a
411+
float32 array of shape ``(364,)``, and each ``*_cam_pos`` is either a
412+
``(26, 3)`` float64 array of camera-space joint positions or ``None``
413+
if that hand is untracked.
270414
"""
271415
state = np.zeros(STATE_DIM, dtype=np.float32)
272416
left_cam_pos = right_cam_pos = None
273417

274-
R_head = quat_to_rotation_matrix(head_rot)
275-
R_inv = R_head.T
418+
q_head_inv = quat_inverse(head_rot)
276419

277420
for hand_idx, (jkey, tkey) in enumerate(
278421
[
@@ -288,9 +431,7 @@ def compute_frame_state(pf, head_pos_tracking, head_rot, eye_offset):
288431

289432
cam_pos = world_to_camera(pos_tracking, head_pos_tracking, head_rot, eye_offset)
290433

291-
# Rotate each joint quaternion into camera frame:
292-
# q_cam = q_head_inv * q_joint
293-
cam_quat = rotate_quaternions_by_matrix(R_inv, quat)
434+
cam_quat = np.array([quat_multiply(q_head_inv, quat[i]) for i in range(len(quat))])
294435

295436
hand_state = np.hstack([cam_pos, cam_quat]) # (26, 7)
296437
offset = hand_idx * STATE_DIM_PER_HAND
@@ -310,6 +451,7 @@ def compute_frame_state(pf, head_pos_tracking, head_rot, eye_offset):
310451

311452

312453
def main():
454+
"""Entry point: parse CLI args, process video + poses, write LeRobot dataset."""
313455
parser = argparse.ArgumentParser(
314456
description="Convert Pico recordings to a LeRobot dataset (with optional overlay video).",
315457
formatter_class=argparse.RawDescriptionHelpFormatter,
@@ -412,6 +554,7 @@ def main():
412554
overlay_writer = cv2.VideoWriter(str(overlay_path), fourcc, video_fps, (video_w, video_h))
413555

414556
states: list[np.ndarray] = []
557+
head_poses: list[np.ndarray] = []
415558
cap = cv2.VideoCapture(str(video_path))
416559

417560
if overlay_writer:
@@ -455,6 +598,7 @@ def main():
455598

456599
if next_sample < len(sampled_indices) and frame_idx == sampled_indices[next_sample]:
457600
states.append(state)
601+
head_poses.append(np.concatenate([head_pos_tracking, head_rot]).astype(np.float32))
458602
next_sample += 1
459603

460604
overlay_writer.release()
@@ -480,6 +624,7 @@ def main():
480624

481625
state, _, _ = compute_frame_state(pf, head_pos_tracking, head_rot, eye_offset)
482626
states.append(state)
627+
head_poses.append(np.concatenate([head_pos_tracking, head_rot]).astype(np.float32))
483628

484629
cap.release()
485630

@@ -488,7 +633,13 @@ def main():
488633
print("Error: no frames processed.", file=sys.stderr)
489634
sys.exit(1)
490635

491-
actions = [states[i + 1].copy() for i in range(num_frames - 1)] + [states[-1].copy()]
636+
actions = []
637+
for i in range(num_frames - 1):
638+
delta_pos = head_poses[i + 1][:3] - head_poses[i][:3]
639+
delta_rot = quat_multiply(head_poses[i + 1][3:], quat_inverse(head_poses[i][3:]))
640+
actions.append(np.concatenate([states[i + 1], delta_pos, delta_rot]).astype(np.float32))
641+
delta_zero = np.array([0, 0, 0, 0, 0, 0, 1], dtype=np.float32)
642+
actions.append(np.concatenate([states[-1], delta_zero]))
492643

493644
# --- Create LeRobot dataset ---
494645
image_key = "observation.images.camera"
@@ -508,7 +659,7 @@ def main():
508659
},
509660
action_key: {
510661
"dtype": "float32",
511-
"shape": (STATE_DIM,),
662+
"shape": (ACTION_DIM,),
512663
"names": None,
513664
},
514665
}

0 commit comments

Comments (0)