Skip to content

Commit 0fe095c

Browse files
committed
Sample axis-angle and translation perturbations around transforms
1 parent eec98b0 commit 0fe095c

File tree

3 files changed

+38
-0
lines changed

3 files changed

+38
-0
lines changed

src/pytorch_kinematics/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,6 @@
2929
so3_rotation_angle,
3030
)
3131
from .transform3d import Rotate, RotateAxisAngle, Scale, Transform3d, Translate
32+
from pytorch_kinematics.transforms.perturbation import sample_perturbations
3233

3334
__all__ = [k for k in globals().keys() if not k.startswith("_")]
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import torch
2+
from pytorch_kinematics.transforms.rotation_conversions import axis_angle_to_matrix
3+
4+
5+
def sample_perturbations(T, num_perturbations, radian_sigma, translation_sigma):
6+
"""
7+
Sample perturbations around the given transform. The translation and rotation are sampled independently from
8+
0 mean gaussians. The angular perturbations' directions are uniformly sampled from the unit sphere while its
9+
magnitude is sampled from a gaussian.
10+
:param T: given transform to perturb around
11+
:param num_perturbations: number of perturbations to sample
12+
:param radian_sigma: standard deviation of the gaussian angular perturbation in radians
13+
:param translation_sigma: standard deviation of the gaussian translation perturbation in meters / T units
14+
:return: perturbed transforms; may not include the original transform
15+
"""
16+
dtype = T.dtype
17+
device = T.device
18+
perturbed = torch.eye(4, dtype=dtype, device=device).repeat(num_perturbations, 1, 1)
19+
20+
delta_R = torch.randn((num_perturbations, 3), dtype=dtype, device=device) * radian_sigma
21+
delta_R = axis_angle_to_matrix(delta_R)
22+
perturbed[:, :3, :3] = delta_R @ T[..., :3, :3]
23+
perturbed[:, :3, 3] = T[..., :3, 3]
24+
25+
delta_t = torch.randn((num_perturbations, 3), dtype=dtype, device=device) * translation_sigma
26+
perturbed[:, :3, 3] += delta_t
27+
28+
return perturbed

src/pytorch_kinematics/transforms/transform3d.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from .rotation_conversions import _axis_angle_rotation, matrix_to_quaternion, quaternion_to_matrix, \
1111
euler_angles_to_matrix
12+
from pytorch_kinematics.transforms.perturbation import sample_perturbations
1213

1314
DEFAULT_EULER_CONVENTION = "XYZ"
1415

@@ -421,6 +422,14 @@ def rotate(self, *args, **kwargs):
421422
def rotate_axis_angle(self, *args, **kwargs):
422423
return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))
423424

425+
def sample_perturbations(self, num_perturbations, radian_sigma, translation_sigma):
426+
mat = self.get_matrix()
427+
if mat.shape[0] == 1:
428+
mat = mat[0]
429+
all_mats = sample_perturbations(mat, num_perturbations, radian_sigma, translation_sigma)
430+
out = Transform3d(matrix=all_mats)
431+
return out
432+
424433
def clone(self):
425434
"""
426435
Deep copy of Transforms object. All internal tensors are cloned

0 commit comments

Comments
 (0)