Skip to content

Commit caf2df3

Browse files
committed
Add batch to batch transforms
1 parent 11de9ad commit caf2df3

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

src/pytorch_kinematics/transforms/transform3d.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .rotation_conversions import _axis_angle_rotation, matrix_to_quaternion, quaternion_to_matrix, \
1111
euler_angles_to_matrix
1212
from pytorch_kinematics.transforms.perturbation import sample_perturbations
13+
from arm_pytorch_utilities import linalg
1314

1415
DEFAULT_EULER_CONVENTION = "XYZ"
1516

@@ -297,7 +298,7 @@ def stack(self, *others):
297298
out = Transform3d(matrix=matrix, device=self.device, dtype=self.dtype)
298299
return out
299300

300-
def transform_points(self, points, eps: Optional[float] = None):
301+
def transform_points(self, points, eps: Optional[float] = None, batch_to_batch=False):
301302
"""
302303
Use this transform to transform a set of 3D points. Assumes row major
303304
ordering of the input points.
@@ -311,6 +312,8 @@ def transform_points(self, points, eps: Optional[float] = None):
311312
torch.clamp(last_coord.abs(), eps),
312313
i.e. the last coordinates that are exactly 0 will
313314
be clamped to +eps.
315+
batch_to_batch: If True, then each transform is applied to the corresponding point instead of all points.
316+
Note that this only makes sense if the number of transforms matches the number of points.
314317
315318
Returns:
316319
points_out: points of shape (N, P, 3) or (P, 3) depending
@@ -328,7 +331,10 @@ def transform_points(self, points, eps: Optional[float] = None):
328331
points_batch = torch.cat([points_batch, ones], dim=2)
329332

330333
composed_matrix = self.get_matrix().transpose(-1, -2)
331-
points_out = _broadcast_bmm(points_batch, composed_matrix)
334+
if batch_to_batch:
335+
points_out = linalg.batch_batch_product(points_batch, composed_matrix)
336+
else:
337+
points_out = _broadcast_bmm(points_batch, composed_matrix)
332338
denom = points_out[..., 3:] # denominator
333339
if eps is not None:
334340
denom_sign = denom.sign() + (denom == 0.0).type_as(denom)
@@ -342,12 +348,14 @@ def transform_points(self, points, eps: Optional[float] = None):
342348

343349
return points_out
344350

345-
def transform_normals(self, normals):
351+
def transform_normals(self, normals, batch_to_batch=False):
346352
"""
347353
Use this transform to transform a set of normal vectors.
348354
349355
Args:
350356
normals: Tensor of shape (P, 3) or (N, P, 3)
357+
batch_to_batch: If True, then each transform is applied to the corresponding normal instead of all normals.
358+
Note that this only makes sense if the number of transforms matches the number of normals.
351359
352360
Returns:
353361
normals_out: Tensor of shape (P, 3) or (N, P, 3) depending
@@ -357,7 +365,11 @@ def transform_normals(self, normals):
357365
msg = "Expected normals to have dim = 2 or dim = 3: got shape %r"
358366
raise ValueError(msg % (normals.shape,))
359367
mat = self.inverse().get_matrix()[:, :3, :3]
360-
normals_out = _broadcast_bmm(normals, mat)
368+
369+
if batch_to_batch:
370+
normals_out = linalg.batch_batch_product(normals, mat)
371+
else:
372+
normals_out = _broadcast_bmm(normals, mat)
361373

362374
# This doesn't pass unit tests. TODO investigate further
363375
# if self._lu is None:

0 commit comments

Comments
 (0)