10
10
from .rotation_conversions import _axis_angle_rotation , matrix_to_quaternion , quaternion_to_matrix , \
11
11
euler_angles_to_matrix
12
12
from pytorch_kinematics .transforms .perturbation import sample_perturbations
13
+ from arm_pytorch_utilities import linalg
13
14
14
15
DEFAULT_EULER_CONVENTION = "XYZ"
15
16
@@ -297,7 +298,7 @@ def stack(self, *others):
297
298
out = Transform3d (matrix = matrix , device = self .device , dtype = self .dtype )
298
299
return out
299
300
300
- def transform_points (self , points , eps : Optional [float ] = None ):
301
+ def transform_points (self , points , eps : Optional [float ] = None , batch_to_batch = False ):
301
302
"""
302
303
Use this transform to transform a set of 3D points. Assumes row major
303
304
ordering of the input points.
@@ -311,6 +312,8 @@ def transform_points(self, points, eps: Optional[float] = None):
311
312
torch.clamp(last_coord.abs(), eps),
312
313
i.e. the last coordinates that are exactly 0 will
313
314
be clamped to +eps.
315
+ batch_to_batch: If True, then each transform is applied to the corresponding point instead of all points.
316
+ Note that this only makes sense if the number of transforms matches the number of points.
314
317
315
318
Returns:
316
319
points_out: points of shape (N, P, 3) or (P, 3) depending
@@ -328,7 +331,10 @@ def transform_points(self, points, eps: Optional[float] = None):
328
331
points_batch = torch .cat ([points_batch , ones ], dim = 2 )
329
332
330
333
composed_matrix = self .get_matrix ().transpose (- 1 , - 2 )
331
- points_out = _broadcast_bmm (points_batch , composed_matrix )
334
+ if batch_to_batch :
335
+ points_out = linalg .batch_batch_product (points_batch , composed_matrix )
336
+ else :
337
+ points_out = _broadcast_bmm (points_batch , composed_matrix )
332
338
denom = points_out [..., 3 :] # denominator
333
339
if eps is not None :
334
340
denom_sign = denom .sign () + (denom == 0.0 ).type_as (denom )
@@ -342,12 +348,14 @@ def transform_points(self, points, eps: Optional[float] = None):
342
348
343
349
return points_out
344
350
345
- def transform_normals (self , normals ):
351
+ def transform_normals (self , normals , batch_to_batch = False ):
346
352
"""
347
353
Use this transform to transform a set of normal vectors.
348
354
349
355
Args:
350
356
normals: Tensor of shape (P, 3) or (N, P, 3)
357
+ batch_to_batch: If True, then each transform is applied to the corresponding normal instead of all normals.
358
+ Note that this only makes sense if the number of transforms matches the number of normals.
351
359
352
360
Returns:
353
361
normals_out: Tensor of shape (P, 3) or (N, P, 3) depending
@@ -357,7 +365,11 @@ def transform_normals(self, normals):
357
365
msg = "Expected normals to have dim = 2 or dim = 3: got shape %r"
358
366
raise ValueError (msg % (normals .shape ,))
359
367
mat = self .inverse ().get_matrix ()[:, :3 , :3 ]
360
- normals_out = _broadcast_bmm (normals , mat )
368
+
369
+ if batch_to_batch :
370
+ normals_out = linalg .batch_batch_product (normals , mat )
371
+ else :
372
+ normals_out = _broadcast_bmm (normals , mat )
361
373
362
374
# This doesn't pass unit tests. TODO investigate further
363
375
# if self._lu is None:
0 commit comments