@@ -275,7 +275,20 @@ def _get_matrix_inverse(self):
275
275
"""
276
276
Return the inverse of self._matrix.
277
277
"""
278
- return torch .inverse (self ._matrix )
278
+
279
+ return self ._invert_transformation_matrix (self ._matrix )
280
+
281
+ @staticmethod
282
+ def _invert_transformation_matrix (T ):
283
+ """
284
+ Inverts homogeneous transformation matrix
285
+ """
286
+ Tinv = T .clone ()
287
+ R = T [:, :3 , :3 ]
288
+ t = T [:, :3 , 3 ]
289
+ Tinv [:, :3 , :3 ] = R .transpose (1 , 2 )
290
+ Tinv [:, :3 , 3 :] = - Tinv [:, :3 , :3 ] @ t .unsqueeze (- 1 )
291
+ return Tinv
279
292
280
293
def inverse (self , invert_composed : bool = False ):
281
294
"""
@@ -293,15 +306,15 @@ def inverse(self, invert_composed: bool = False):
293
306
independently without composing them.
294
307
295
308
Returns:
296
- A new Transform3D object contaning the inverse of the original
309
+ A new Transform3D object containing the inverse of the original
297
310
transformation.
298
311
"""
299
312
300
313
tinv = Transform3d (device = self .device )
301
314
302
315
if invert_composed :
303
316
# first compose then invert
304
- tinv ._matrix = torch . inverse (self .get_matrix ())
317
+ tinv ._matrix = self . _invert_transformation_matrix (self .get_matrix ())
305
318
else :
306
319
# self._get_matrix_inverse() implements efficient inverse
307
320
# of self._matrix
@@ -392,10 +405,7 @@ def transform_normals(self, normals):
392
405
if normals .dim () not in [2 , 3 ]:
393
406
msg = "Expected normals to have dim = 2 or dim = 3: got shape %r"
394
407
raise ValueError (msg % (normals .shape ,))
395
- composed_matrix = self .get_matrix ()
396
-
397
- # TODO: inverse is bad! Solve a linear system instead
398
- mat = composed_matrix [:, :3 , :3 ]
408
+ mat = self ._get_matrix_inverse ()[:, :3 , :3 ]
399
409
normals_out = _broadcast_bmm (normals , mat .inverse ())
400
410
401
411
# This doesn't pass unit tests. TODO investigate further
@@ -410,6 +420,31 @@ def transform_normals(self, normals):
410
420
411
421
return normals_out
412
422
423
+ def transform_shape_operator (self , shape_operators ):
424
+ """
425
+ Use this transform to transform a set of shape_operator (or Weingarten map).
426
+ This is the hessian of a signed-distance, i.e. gradient of a normal vector.
427
+
428
+ Args:
429
+ shape_operators: Tensor of shape (P, 3, 3) or (N, P, 3, 3)
430
+
431
+ Returns:
432
+ shape_operators_out: Tensor of shape (P, 3, 3) or (N, P, 3, 3) depending
433
+ on the dimensions of the transform
434
+ """
435
+ if shape_operators .dim () not in [3 , 4 ]:
436
+ msg = "Expected shape_operators to have dim = 3 or dim = 4: got shape %r"
437
+ raise ValueError (msg % (shape_operators .shape ,))
438
+ mat = self ._get_matrix_inverse ()[:, :3 , :3 ]
439
+ shape_operators_out = _broadcast_bmm (mat .permute (0 , 2 , 1 ), _broadcast_bmm (shape_operators , mat ))
440
+
441
+ # When transform is (1, 4, 4) and shape_operator is (P, 3, 3) return
442
+ # shape_operators_out of shape (P, 3, 3)
443
+ if shape_operators_out .shape [0 ] == 1 and shape_operators .dim () == 3 :
444
+ shape_operators_out = shape_operators_out .reshape (shape_operators .shape )
445
+
446
+ return shape_operators_out
447
+
413
448
def translate (self , * args , ** kwargs ):
414
449
return self .compose (Translate (device = self .device , * args , ** kwargs ))
415
450
0 commit comments