Skip to content

Commit 5eafa21

Browse files
committed
Fix sizing of ones padded to position for homogenous coordinates
1 parent 2a0ebf3 commit 5eafa21

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

src/pytorch_kinematics/transforms/transform3d.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,13 @@ def __init__(
185185
self._matrix = matrix.view(-1, 4, 4)
186186

187187
if pos is not None:
188-
ones = torch.ones(1, dtype=dtype, device=device)
188+
ones = torch.ones([1], dtype=dtype, device=device)
189189
if not torch.is_tensor(pos):
190190
pos = torch.tensor(pos, dtype=dtype, device=device)
191-
if pos.ndim in (2, 3) and pos.shape[0] > 1 and self._matrix.shape[0] == 1:
192-
self._matrix = self._matrix.repeat(pos.shape[0], 1, 1)
193-
ones = ones.repeat(pos.shape[0], 1)
194-
# self._matrix[:, :3, 3] = pos
191+
if pos.ndim in (2, 3):
192+
ones = ones.repeat(*pos.shape[:-1], 1)
193+
if pos.ndim in (2, 3) and pos.shape[0] > 1 and self._matrix.shape[0] == 1:
194+
self._matrix = self._matrix.repeat(pos.shape[0], 1, 1)
195195
pos_h = torch.cat((pos, ones), dim=-1).reshape(-1, 4, 1)
196196
self._matrix = torch.cat((self._matrix[:, :, :3], pos_h), dim=-1)
197197

tests/test_jacobian.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@ def get_pt(th):
176176
pk_end = timer()
177177
# we can only compare the positional parts
178178
assert torch.allclose(j1_, j2[:, :3], atol=1e-6)
179-
print(f"for N={N} on {d} autograd:{(pk_start - autograd_start) * 1000}ms "
180-
f"pytorch-kinematics:{(pk_end - pk_start) * 1000}ms")
179+
print(f"for N={N} on {d} autograd:{(pk_start - autograd_start) * 1000}ms")
180+
print(f"for N={N} on {d} pytorch-kinematics:{(pk_end - pk_start) * 1000}ms")
181181
# if we have functools (for pytorch>=1.13.0 it comes with installing pytorch)
182182
try:
183183
import functorch

0 commit comments

Comments
 (0)