Skip to content

Commit a2e1274

Browse files
committed
Add SE(3) conversion between 9D continuous representation and 4x4 matrix
1 parent caf2df3 commit a2e1274

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

src/pytorch_kinematics/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
axis_and_d_to_pris_matrix,
2424
wxyz_to_xyzw,
2525
xyzw_to_wxyz,
26+
matrix44_to_se3_9d,
27+
se3_9d_to_matrix44,
2628
)
2729
from .so3 import (
2830
so3_exp_map,

src/pytorch_kinematics/transforms/rotation_conversions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,21 @@ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
662662
return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
663663

664664

665+
def matrix44_to_se3_9d(matrix: torch.Tensor) -> torch.Tensor:
666+
r = matrix_to_rotation_6d(matrix[..., :3, :3])
667+
t = matrix[..., :3, 3]
668+
return torch.cat([r, t], dim=-1)
669+
670+
671+
def se3_9d_to_matrix44(se3: torch.Tensor) -> torch.Tensor:
672+
r = rotation_6d_to_matrix(se3[..., :6])
673+
t = se3[..., 6:]
674+
H = torch.eye(4, device=r.device, dtype=r.dtype).repeat(r.shape[:-2] + (1, 1))
675+
H[..., :3, :3] = r
676+
H[..., :3, 3] = t
677+
return H
678+
679+
665680
def matrix_to_pos_rot(m):
666681
"""Convert 4x4 transformation matrix to (position, xyzw quatnerion) used by pybullet and RViz"""
667682
pos = m[..., :3, 3]

0 commit comments

Comments
 (0)