Skip to content

Commit 3ce4dcf

Browse files
committed
Fix initializing transform with a batch rotation matrix
1 parent 164cef7 commit 3ce4dcf

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

src/pytorch_kinematics/transforms/transform3d.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,9 @@ def __init__(
204204
elif rot.shape[-1] == 3 and (len(rot.shape) == 1 or rot.shape[-2] != 3):
205205
rot = euler_angles_to_matrix(rot, DEFAULT_EULER_CONVENTION)
206206
if rot.ndim == 3:
207-
zeros = zeros.unsqueeze(0)
207+
zeros = zeros.repeat(rot.shape[0], 1, 1)
208208
if rot.shape[0] > 1 and self._matrix.shape[0] == 1:
209209
self._matrix = self._matrix.repeat(rot.shape[0], 1, 1)
210-
zeros = zeros.repeat(rot.shape[0], 1, 1)
211210
rot_h = torch.cat((rot, zeros), dim=-2).reshape(-1, 4, 3)
212211
self._matrix = torch.cat((rot_h, self._matrix[:, :, 3].reshape(-1, 4, 1)), dim=-1)
213212

0 commit comments

Comments
 (0)