Skip to content

Commit bc9c65d

Browse files
author
Peter
committed
de-duplicate axis_angle conversion functions
Also remove dependency on transformations, since we only use one function from it, and it had compatibility issues. I couldn't get it to install on my laptop.
1 parent 81f2caf commit bc9c65d

File tree

5 files changed

+109
-38
lines changed

5 files changed

+109
-38
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ dependencies = [
4545
'numpy',
4646
'pyyaml',
4747
'torch',
48-
'transformations',
4948
]
5049

5150
[project.optional-dependencies]

src/pytorch_kinematics/chain.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from functools import lru_cache
2-
from pytorch_kinematics.transforms.rotation_conversions import tensor_axis_and_angle_to_matrix
3-
from pytorch_kinematics.transforms.rotation_conversions import tensor_axis_and_d_to_pris_matrix
2+
from pytorch_kinematics.transforms.rotation_conversions import axis_and_angle_to_matrix
3+
from pytorch_kinematics.transforms.rotation_conversions import axis_and_d_to_pris_matrix
44
from typing import Optional, Sequence
55

66
import numpy as np
@@ -294,8 +294,8 @@ def forward_kinematics(self, th, frame_indices: Optional = None):
294294
# compute all joint transforms at once first
295295
# in order to handle multiple joint types without branching, we create all possible transforms
296296
# for all joint types and then select the appropriate one for each joint.
297-
rev_jnt_transform = tensor_axis_and_angle_to_matrix(axes_expanded, th)
298-
pris_jnt_transform = tensor_axis_and_d_to_pris_matrix(axes_expanded, th)
297+
rev_jnt_transform = axis_and_angle_to_matrix(axes_expanded, th)
298+
pris_jnt_transform = axis_and_d_to_pris_matrix(axes_expanded, th)
299299

300300
frame_transforms = {}
301301
b = th.shape[0]
@@ -456,7 +456,7 @@ def jacobian(self, th, locations=None):
456456

457457
def forward_kinematics(self, th, end_only: bool = True):
458458
""" Like the base class, except `th` only needs to contain the joints in the SerialChain, not all joints. """
459-
frame_indices, th = self.convert_serial_inputs_to_chain_inputs(end_only, th)
459+
frame_indices, th = self.convert_serial_inputs_to_chain_inputs(th, end_only)
460460

461461
mat = super().forward_kinematics(th, frame_indices)
462462

@@ -465,7 +465,7 @@ def forward_kinematics(self, th, end_only: bool = True):
465465
else:
466466
return mat
467467

468-
def convert_serial_inputs_to_chain_inputs(self, end_only, th):
468+
def convert_serial_inputs_to_chain_inputs(self, th, end_only: bool):
469469
if end_only:
470470
frame_indices = self.get_frame_indices(self._serial_frames[-1].name)
471471
else:

src/pytorch_kinematics/transforms/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

3+
from pytorch_kinematics.transforms.perturbation import sample_perturbations
34
from .rotation_conversions import (
4-
axis_and_angle_to_matrix,
55
axis_angle_to_quaternion,
66
euler_angles_to_matrix,
77
matrix_to_axis_angle,
@@ -13,13 +13,14 @@
1313
quaternion_multiply,
1414
quaternion_raw_multiply,
1515
quaternion_to_matrix,
16+
quaternion_from_euler,
1617
random_quaternions,
1718
random_rotation,
1819
random_rotations,
1920
rotation_6d_to_matrix,
2021
standardize_quaternion,
21-
tensor_axis_and_angle_to_matrix,
22-
tensor_axis_and_d_to_pris_matrix,
22+
axis_and_angle_to_matrix,
23+
axis_and_d_to_pris_matrix,
2324
wxyz_to_xyzw,
2425
xyzw_to_wxyz,
2526
)
@@ -30,6 +31,5 @@
3031
so3_rotation_angle,
3132
)
3233
from .transform3d import Rotate, RotateAxisAngle, Scale, Transform3d, Translate
33-
from pytorch_kinematics.transforms.perturbation import sample_perturbations
3434

3535
__all__ = [k for k in globals().keys() if not k.startswith("_")]

src/pytorch_kinematics/transforms/rotation_conversions.py

Lines changed: 98 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

33
import functools
4+
import math
45
from typing import Optional
56
from warnings import warn
67

8+
import numpy
79
import torch
810
import torch.nn.functional as F
911

@@ -450,7 +452,7 @@ def quaternion_apply(quaternion, point):
450452
return out[..., 1:]
451453

452454

453-
def tensor_axis_and_d_to_pris_matrix(axis, d):
455+
def axis_and_d_to_pris_matrix(axis, d):
454456
"""
455457
Creates a 4x4 matrix that represents a translation along an axis of a distance d
456458
Works with any number of batch dimensions.
@@ -470,7 +472,7 @@ def tensor_axis_and_d_to_pris_matrix(axis, d):
470472
return mat44
471473

472474

473-
def tensor_axis_and_angle_to_matrix(axis, theta):
475+
def axis_and_angle_to_matrix(axis, theta):
474476
"""
475477
Creates a 4x4 matrix that represents a rotation around an axis by an angle theta.
476478
Works with any number of batch dimensions.
@@ -505,27 +507,6 @@ def tensor_axis_and_angle_to_matrix(axis, theta):
505507
return mat44
506508

507509

508-
def axis_and_angle_to_matrix(axis, theta):
509-
# based on https://ai.stackexchange.com/questions/14041/, and checked against wikipedia
510-
c = torch.cos(theta) # NOTE: cos is not that precise for float32, you may want to use float64
511-
one_minus_c = 1 - c
512-
s = torch.sin(theta)
513-
kx, ky, kz = torch.unbind(axis, -1)
514-
r00 = c + kx * kx * one_minus_c
515-
r01 = kx * ky * one_minus_c - kz * s
516-
r02 = kx * kz * one_minus_c + ky * s
517-
r10 = ky * kx * one_minus_c + kz * s
518-
r11 = c + ky * ky * one_minus_c
519-
r12 = ky * kz * one_minus_c - kx * s
520-
r20 = kz * kx * one_minus_c - ky * s
521-
r21 = kz * ky * one_minus_c + kx * s
522-
r22 = c + kz * kz * one_minus_c
523-
rot = torch.stack([torch.cat([r00, r01, r02], -1),
524-
torch.cat([r10, r11, r12], -1),
525-
torch.cat([r20, r21, r22], -1)], -2)
526-
return rot
527-
528-
529510
def axis_angle_to_matrix(axis_angle):
530511
"""
531512
Convert rotations given as axis/angle to rotation matrices.
@@ -541,7 +522,7 @@ def axis_angle_to_matrix(axis_angle):
541522
Returns:
542523
Rotation matrices as tensor of shape (..., 3, 3).
543524
"""
544-
warn('This is deprecated because it is slow. Use axis_and_angle_to_matrix or zpk_cpp.axis_and_angle_to_matrix',
525+
warn('This is deprecated because it is slow. Use axis_and_angle_to_matrix',
545526
DeprecationWarning, stacklevel=2)
546527
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
547528

@@ -682,3 +663,96 @@ def pos_rot_to_matrix(pos, rot):
682663
m[..., :3, 3] = pos
683664
m[..., :3, :3] = rot
684665
return m
666+
667+
668+
# axis sequences for Euler angles
669+
_NEXT_AXIS = [1, 2, 0, 1]
670+
671+
# map axes strings to/from tuples of inner axis, parity, repetition, frame
672+
_AXES2TUPLE = {
673+
'sxyz': (0, 0, 0, 0),
674+
'sxyx': (0, 0, 1, 0),
675+
'sxzy': (0, 1, 0, 0),
676+
'sxzx': (0, 1, 1, 0),
677+
'syzx': (1, 0, 0, 0),
678+
'syzy': (1, 0, 1, 0),
679+
'syxz': (1, 1, 0, 0),
680+
'syxy': (1, 1, 1, 0),
681+
'szxy': (2, 0, 0, 0),
682+
'szxz': (2, 0, 1, 0),
683+
'szyx': (2, 1, 0, 0),
684+
'szyz': (2, 1, 1, 0),
685+
'rzyx': (0, 0, 0, 1),
686+
'rxyx': (0, 0, 1, 1),
687+
'ryzx': (0, 1, 0, 1),
688+
'rxzx': (0, 1, 1, 1),
689+
'rxzy': (1, 0, 0, 1),
690+
'ryzy': (1, 0, 1, 1),
691+
'rzxy': (1, 1, 0, 1),
692+
'ryxy': (1, 1, 1, 1),
693+
'ryxz': (2, 0, 0, 1),
694+
'rzxz': (2, 0, 1, 1),
695+
'rxyz': (2, 1, 0, 1),
696+
'rzyz': (2, 1, 1, 1),
697+
}
698+
699+
_TUPLE2AXES = {v: k for k, v in _AXES2TUPLE.items()}
700+
701+
702+
def quaternion_from_euler(ai, aj, ak, axes='sxyz'):
703+
"""
704+
Return quaternion from Euler angles and axis sequence.
705+
Taken from https://github.com/cgohlke/transformations/blob/master/transformations/transformations.py#L1238
706+
707+
ai, aj, ak : Euler's roll, pitch and yaw angles
708+
axes : One of 24 axis sequences as string or encoded tuple
709+
710+
>>> q = quaternion_from_euler(1, 2, 3, 'ryxz')
711+
>>> numpy.allclose(q, [0.435953, 0.310622, -0.718287, 0.444435])
712+
True
713+
714+
"""
715+
try:
716+
firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
717+
except (AttributeError, KeyError):
718+
_TUPLE2AXES[axes] # noqa: validation
719+
firstaxis, parity, repetition, frame = axes
720+
721+
i = firstaxis + 1
722+
j = _NEXT_AXIS[i + parity - 1] + 1
723+
k = _NEXT_AXIS[i - parity] + 1
724+
725+
if frame:
726+
ai, ak = ak, ai
727+
if parity:
728+
aj = -aj
729+
730+
ai /= 2.0
731+
aj /= 2.0
732+
ak /= 2.0
733+
ci = math.cos(ai)
734+
si = math.sin(ai)
735+
cj = math.cos(aj)
736+
sj = math.sin(aj)
737+
ck = math.cos(ak)
738+
sk = math.sin(ak)
739+
cc = ci * ck
740+
cs = ci * sk
741+
sc = si * ck
742+
ss = si * sk
743+
744+
q = numpy.empty((4,))
745+
if repetition:
746+
q[0] = cj * (cc - ss)
747+
q[i] = cj * (cs + sc)
748+
q[j] = sj * (cc + ss)
749+
q[k] = sj * (cs - sc)
750+
else:
751+
q[0] = cj * cc + sj * ss
752+
q[i] = cj * sc - sj * cs
753+
q[j] = cj * ss + sj * cc
754+
q[k] = cj * cs - sj * sc
755+
if parity:
756+
q[j] *= -1.0
757+
758+
return q

src/pytorch_kinematics/urdf.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
from . import chain
44
import torch
55
import pytorch_kinematics.transforms as tf
6-
# has better RPY to quaternion transformation
7-
import transformations as tf2
86

97
JOINT_TYPE_MAP = {'revolute': 'revolute',
108
'continuous': 'revolute',
@@ -16,7 +14,7 @@ def _convert_transform(origin):
1614
if origin is None:
1715
return tf.Transform3d()
1816
else:
19-
return tf.Transform3d(rot=torch.tensor(tf2.quaternion_from_euler(*origin.rpy, "sxyz"), dtype=torch.float32),
17+
return tf.Transform3d(rot=torch.tensor(tf.quaternion_from_euler(*origin.rpy, "sxyz"), dtype=torch.float32),
2018
pos=origin.xyz)
2119

2220

0 commit comments

Comments
 (0)