Skip to content

Commit b21e7c8

Browse files
Update src/pytorch_kinematics/transforms/rotation_conversions.py
Co-authored-by: Tom Power <[email protected]>
1 parent 5794eb5 commit b21e7c8

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/pytorch_kinematics/transforms/rotation_conversions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -462,9 +462,10 @@ def tensor_axis_and_d_to_pris_matrix(axis, d):
462462
463463
"""
464464
batch_axes = axis.shape[:-1]
465-
mat44 = torch.eye(4).to(axis).repeat(*batch_axes, 1, 1)
466-
pos = axis * d[..., None]
467-
mat44[..., :3, 3] = pos
465+
mat33 = torch.eye(3).to(axis).expand(*batch_axes, 3, 3)
466+
pos = axis * d.unsqueeze(-1)
467+
mat44 = torch.cat((mat33, pos.unsqueeze(-1)), -1)
468+
mat44 = torch.cat((mat44, torch.tensor([0.0, 0.0, 0.0, 1.0]).expand(*batch_axes, 1, 4).to(axis)), -2)
468469
return mat44
469470

470471

0 commit comments

Comments
 (0)