Skip to content

Commit 652642f

Browse files
committed
change Transform3d.inverse to more efficient inverse. Add Transform3d.transform_shape_operator function
1 parent 6d1fd2d commit 652642f

File tree

1 file changed

+42
-7
lines changed

1 file changed

+42
-7
lines changed

src/pytorch_kinematics/transforms/transform3d.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,20 @@ def _get_matrix_inverse(self):
275275
"""
276276
Return the inverse of self._matrix.
277277
"""
278-
return torch.inverse(self._matrix)
278+
279+
return self._invert_transformation_matrix(self._matrix)
280+
281+
@staticmethod
282+
def _invert_transformation_matrix(T):
283+
"""
284+
Inverts homogeneous transformation matrix
285+
"""
286+
Tinv = T.clone()
287+
R = T[:, :3, :3]
288+
t = T[:, :3, 3]
289+
Tinv[:, :3, :3] = R.transpose(1, 2)
290+
Tinv[:, :3, 3:] = -Tinv[:, :3, :3] @ t.unsqueeze(-1)
291+
return Tinv
279292

280293
def inverse(self, invert_composed: bool = False):
281294
"""
@@ -293,15 +306,15 @@ def inverse(self, invert_composed: bool = False):
293306
independently without composing them.
294307
295308
Returns:
296-
A new Transform3D object contaning the inverse of the original
309+
A new Transform3D object containing the inverse of the original
297310
transformation.
298311
"""
299312

300313
tinv = Transform3d(device=self.device)
301314

302315
if invert_composed:
303316
# first compose then invert
304-
tinv._matrix = torch.inverse(self.get_matrix())
317+
tinv._matrix = self._invert_transformation_matrix(self.get_matrix())
305318
else:
306319
# self._get_matrix_inverse() implements efficient inverse
307320
# of self._matrix
@@ -392,10 +405,7 @@ def transform_normals(self, normals):
392405
if normals.dim() not in [2, 3]:
393406
msg = "Expected normals to have dim = 2 or dim = 3: got shape %r"
394407
raise ValueError(msg % (normals.shape,))
395-
composed_matrix = self.get_matrix()
396-
397-
# TODO: inverse is bad! Solve a linear system instead
398-
mat = composed_matrix[:, :3, :3]
408+
mat = self._get_matrix_inverse()[:, :3, :3]
399409
normals_out = _broadcast_bmm(normals, mat.inverse())
400410

401411
# This doesn't pass unit tests. TODO investigate further
@@ -410,6 +420,31 @@ def transform_normals(self, normals):
410420

411421
return normals_out
412422

423+
def transform_shape_operator(self, shape_operators):
424+
"""
425+
Use this transform to transform a set of shape_operator (or Weingarten map).
426+
This is the hessian of a signed-distance, i.e. gradient of a normal vector.
427+
428+
Args:
429+
shape_operators: Tensor of shape (P, 3, 3) or (N, P, 3, 3)
430+
431+
Returns:
432+
shape_operators_out: Tensor of shape (P, 3, 3) or (N, P, 3, 3) depending
433+
on the dimensions of the transform
434+
"""
435+
if shape_operators.dim() not in [3, 4]:
436+
msg = "Expected shape_operators to have dim = 3 or dim = 4: got shape %r"
437+
raise ValueError(msg % (shape_operators.shape,))
438+
mat = self._get_matrix_inverse()[:, :3, :3]
439+
shape_operators_out = _broadcast_bmm(mat.permute(0, 2, 1), _broadcast_bmm(shape_operators, mat))
440+
441+
# When transform is (1, 4, 4) and shape_operator is (P, 3, 3) return
442+
# shape_operators_out of shape (P, 3, 3)
443+
if shape_operators_out.shape[0] == 1 and shape_operators.dim() == 3:
444+
shape_operators_out = shape_operators_out.reshape(shape_operators.shape)
445+
446+
return shape_operators_out
447+
413448
def translate(self, *args, **kwargs):
414449
return self.compose(Translate(device=self.device, *args, **kwargs))
415450

0 commit comments

Comments
 (0)