Skip to content

Commit 31f1422

Browse files
authored
Merge pull request #27 from UM-ARM-Lab/optimize_transform3d
Remove laziness to speed up computations by about 25%
2 parents b8b07db + 0ebabfc commit 31f1422

File tree

1 file changed

+12
-64
lines changed

1 file changed

+12
-64
lines changed

src/pytorch_kinematics/transforms/transform3d.py

Lines changed: 12 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,6 @@ def __init__(
212212
rot_h = torch.cat((rot, zeros), dim=-2).reshape(-1, 4, 3)
213213
self._matrix = torch.cat((rot_h, self._matrix[:, :, 3].reshape(-1, 4, 1)), dim=-1)
214214

215-
self._transforms = [] # store transforms to compose
216215
self._lu = None
217216
self.device = device
218217
self.dtype = self._matrix.dtype
@@ -240,37 +239,19 @@ def compose(self, *others):
240239
Returns:
241240
A new Transform3d with the stored transforms
242241
"""
243-
out = Transform3d(device=self.device, dtype=self.dtype)
244-
out._matrix = self._matrix.clone()
242+
243+
mat = self._matrix
245244
for other in others:
246-
if not isinstance(other, Transform3d):
247-
msg = "Only possible to compose Transform3d objects; got %s"
248-
raise ValueError(msg % type(other))
249-
out._transforms = self._transforms + list(others)
245+
mat = _broadcast_bmm(mat, other.get_matrix())
246+
247+
out = Transform3d(device=self.device, dtype=self.dtype, matrix=mat)
250248
return out
251249

252250
def get_matrix(self):
253251
"""
254-
Return a matrix which is the result of composing this transform
255-
with others stored in self.transforms. Where necessary transforms
256-
are broadcast against each other.
257-
For example, if self.transforms contains transforms t1, t2, and t3, and
258-
given a set of points x, the following should be true:
259-
260-
.. code-block:: python
261-
262-
y1 = t1.compose(t2, t3).transform(x)
263-
y2 = t3.transform(t2.transform(t1.transform(x)))
264-
y1.get_matrix() == y2.get_matrix()
265-
266-
Returns:
267-
A transformation matrix representing the composed inputs.
252+
Return the Nx4x4 homogeneous transformation matrix represented by this object.
268253
"""
269-
composed_matrix = self._matrix
270-
for other in self._transforms:
271-
other_matrix = other.get_matrix()
272-
composed_matrix = _broadcast_bmm(composed_matrix, other_matrix)
273-
return composed_matrix
254+
return self._matrix
274255

275256
def _get_matrix_inverse(self):
276257
"""
@@ -282,7 +263,7 @@ def _get_matrix_inverse(self):
282263
@staticmethod
283264
def _invert_transformation_matrix(T):
284265
"""
285-
Inverts homogeneous transformation matrix
266+
Invert homogeneous transformation matrix.
286267
"""
287268
Tinv = T.clone()
288269
R = T[:, :3, :3]
@@ -297,54 +278,23 @@ def inverse(self, invert_composed: bool = False):
297278
current transformation.
298279
299280
Args:
300-
invert_composed:
301-
- True: First compose the list of stored transformations
302-
and then apply inverse to the result. This is
303-
potentially slower for classes of transformations
304-
with inverses that can be computed efficiently
305-
(e.g. rotations and translations).
306-
- False: Invert the individual stored transformations
307-
independently without composing them.
281+
invert_composed: ignored, included for backwards compatibility
308282
309283
Returns:
310284
A new Transform3D object containing the inverse of the original
311285
transformation.
312286
"""
313287

314-
tinv = Transform3d(device=self.device)
288+
i_matrix = self._get_matrix_inverse()
315289

316-
if invert_composed:
317-
# first compose then invert
318-
tinv._matrix = self._invert_transformation_matrix(self.get_matrix())
319-
else:
320-
# self._get_matrix_inverse() implements efficient inverse
321-
# of self._matrix
322-
i_matrix = self._get_matrix_inverse()
323-
324-
# 2 cases:
325-
if len(self._transforms) > 0:
326-
# a) Either we have a non-empty list of transforms:
327-
# Here we take self._matrix and append its inverse at the
328-
# end of the reverted _transforms list. After composing
329-
# the transformations with get_matrix(), this correctly
330-
# right-multiplies by the inverse of self._matrix
331-
# at the end of the composition.
332-
tinv._transforms = [t.inverse() for t in reversed(self._transforms)]
333-
last = Transform3d(device=self.device)
334-
last._matrix = i_matrix
335-
tinv._transforms.append(last)
336-
else:
337-
# b) Or there are no stored transformations
338-
# we just set inverted matrix
339-
tinv._matrix = i_matrix
290+
tinv = Transform3d(matrix=i_matrix, device=self.device)
340291

341292
return tinv
342293

343294
def stack(self, *others):
344295
transforms = [self] + list(others)
345296
matrix = torch.cat([t._matrix for t in transforms], dim=0)
346-
out = Transform3d()
347-
out._matrix = matrix
297+
out = Transform3d(matrix=matrix, device=self.device, dtype=self.dtype)
348298
return out
349299

350300
def transform_points(self, points, eps: Optional[float] = None):
@@ -478,7 +428,6 @@ def clone(self):
478428
if self._lu is not None:
479429
other._lu = [elem.clone() for elem in self._lu]
480430
other._matrix = self._matrix.clone()
481-
other._transforms = [t.clone() for t in self._transforms]
482431
return other
483432

484433
def to(self, device, copy: bool = False, dtype=None):
@@ -504,7 +453,6 @@ def to(self, device, copy: bool = False, dtype=None):
504453
other.device = device
505454
other.dtype = dtype if dtype is not None else other.dtype
506455
other._matrix = self._matrix.to(device=device, dtype=dtype)
507-
other._transforms = [t.to(device, copy=copy, dtype=dtype) for t in other._transforms]
508456
return other
509457

510458
def cpu(self):

0 commit comments

Comments
 (0)