Skip to content

Commit 730fa40

Browse files
committed
Remove laziness to speed up computations by about 25%
1 parent b8b07db commit 730fa40

File tree

1 file changed

+9
-41
lines changed

1 file changed

+9
-41
lines changed

src/pytorch_kinematics/transforms/transform3d.py

Lines changed: 9 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,6 @@ def __init__(
212212
rot_h = torch.cat((rot, zeros), dim=-2).reshape(-1, 4, 3)
213213
self._matrix = torch.cat((rot_h, self._matrix[:, :, 3].reshape(-1, 4, 1)), dim=-1)
214214

215-
self._transforms = [] # store transforms to compose
216215
self._lu = None
217216
self.device = device
218217
self.dtype = self._matrix.dtype
@@ -240,13 +239,12 @@ def compose(self, *others):
240239
Returns:
241240
A new Transform3d with the stored transforms
242241
"""
243-
out = Transform3d(device=self.device, dtype=self.dtype)
244-
out._matrix = self._matrix.clone()
242+
243+
mat = self._matrix
245244
for other in others:
246-
if not isinstance(other, Transform3d):
247-
msg = "Only possible to compose Transform3d objects; got %s"
248-
raise ValueError(msg % type(other))
249-
out._transforms = self._transforms + list(others)
245+
mat = _broadcast_bmm(mat, other.get_matrix())
246+
247+
out = Transform3d(device=self.device, dtype=self.dtype, matrix=mat)
250248
return out
251249

252250
def get_matrix(self):
@@ -266,11 +264,7 @@ def get_matrix(self):
266264
Returns:
267265
A transformation matrix representing the composed inputs.
268266
"""
269-
composed_matrix = self._matrix
270-
for other in self._transforms:
271-
other_matrix = other.get_matrix()
272-
composed_matrix = _broadcast_bmm(composed_matrix, other_matrix)
273-
return composed_matrix
267+
return self._matrix
274268

275269
def _get_matrix_inverse(self):
276270
"""
@@ -311,40 +305,16 @@ def inverse(self, invert_composed: bool = False):
311305
transformation.
312306
"""
313307

314-
tinv = Transform3d(device=self.device)
308+
i_matrix = self._get_matrix_inverse()
315309

316-
if invert_composed:
317-
# first compose then invert
318-
tinv._matrix = self._invert_transformation_matrix(self.get_matrix())
319-
else:
320-
# self._get_matrix_inverse() implements efficient inverse
321-
# of self._matrix
322-
i_matrix = self._get_matrix_inverse()
323-
324-
# 2 cases:
325-
if len(self._transforms) > 0:
326-
# a) Either we have a non-empty list of transforms:
327-
# Here we take self._matrix and append its inverse at the
328-
# end of the reverted _transforms list. After composing
329-
# the transformations with get_matrix(), this correctly
330-
# right-multiplies by the inverse of self._matrix
331-
# at the end of the composition.
332-
tinv._transforms = [t.inverse() for t in reversed(self._transforms)]
333-
last = Transform3d(device=self.device)
334-
last._matrix = i_matrix
335-
tinv._transforms.append(last)
336-
else:
337-
# b) Or there are no stored transformations
338-
# we just set inverted matrix
339-
tinv._matrix = i_matrix
310+
tinv = Transform3d(matrix=i_matrix, device=self.device)
340311

341312
return tinv
342313

343314
def stack(self, *others):
344315
transforms = [self] + list(others)
345316
matrix = torch.cat([t._matrix for t in transforms], dim=0)
346-
out = Transform3d()
347-
out._matrix = matrix
317+
out = Transform3d(matrix=matrix, device=self.device, dtype=self.dtype)
348318
return out
349319

350320
def transform_points(self, points, eps: Optional[float] = None):
@@ -478,7 +448,6 @@ def clone(self):
478448
if self._lu is not None:
479449
other._lu = [elem.clone() for elem in self._lu]
480450
other._matrix = self._matrix.clone()
481-
other._transforms = [t.clone() for t in self._transforms]
482451
return other
483452

484453
def to(self, device, copy: bool = False, dtype=None):
@@ -504,7 +473,6 @@ def to(self, device, copy: bool = False, dtype=None):
504473
other.device = device
505474
other.dtype = dtype if dtype is not None else other.dtype
506475
other._matrix = self._matrix.to(device=device, dtype=dtype)
507-
other._transforms = [t.to(device, copy=copy, dtype=dtype) for t in other._transforms]
508476
return other
509477

510478
def cpu(self):

0 commit comments

Comments
 (0)