Skip to content

Commit dd7a3fb

Browse files
author
Peter
committed
change how we test for quaternion equality
1 parent 3ff7f3c commit dd7a3fb

File tree

5 files changed

+61
-52
lines changed

5 files changed

+61
-52
lines changed

src/pytorch_kinematics/transforms/math.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,16 @@
1010
import torch
1111

1212

13+
def quaternion_close(q1: torch.Tensor, q2: torch.Tensor, eps: float = 1e-4):
14+
"""
15+
Returns true if two quaternions are close to each other. Assumes the quaternions are normalized.
16+
Based on: https://math.stackexchange.com/a/90098/516340
17+
18+
"""
19+
dist = 1 - torch.square(torch.sum(q1*q2, dim=-1))
20+
return torch.all(dist < eps)
21+
22+
1323
def acos_linear_extrapolation(
1424
x: torch.Tensor,
1525
bound: Union[float, Tuple[float, float]] = 1.0 - 1e-4,

src/pytorch_kinematics/transforms/rotation_conversions.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

33
import functools
4-
import math
54
from typing import Optional
65
from warnings import warn
76

8-
import numpy
97
import torch
108
import torch.nn.functional as F
119

@@ -715,25 +713,22 @@ def pos_rot_to_matrix(pos, rot):
715713
_TUPLE2AXES = {v: k for k, v in _AXES2TUPLE.items()}
716714

717715

718-
def quaternion_from_euler(ai, aj, ak, axes='sxyz'):
716+
def quaternion_from_euler(rpy, axes='sxyz'):
719717
"""
720718
Return quaternion from Euler angles and axis sequence.
721719
Taken from https://github.com/cgohlke/transformations/blob/master/transformations/transformations.py#L1238
722720
723721
ai, aj, ak : Euler's roll, pitch and yaw angles
724722
axes : One of 24 axis sequences as string or encoded tuple
725723
726-
>>> q = quaternion_from_euler(1, 2, 3, 'ryxz')
727-
>>> numpy.allclose(q, [0.435953, 0.310622, -0.718287, 0.444435])
728-
True
729-
730724
"""
731725
try:
732726
firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
733727
except (AttributeError, KeyError):
734728
_TUPLE2AXES[axes] # noqa: validation
735729
firstaxis, parity, repetition, frame = axes
736730

731+
ai, aj, ak = torch.unbind(rpy, -1)
737732
i = firstaxis + 1
738733
j = _NEXT_AXIS[i + parity - 1] + 1
739734
k = _NEXT_AXIS[i - parity] + 1
@@ -746,29 +741,29 @@ def quaternion_from_euler(ai, aj, ak, axes='sxyz'):
746741
ai /= 2.0
747742
aj /= 2.0
748743
ak /= 2.0
749-
ci = math.cos(ai)
750-
si = math.sin(ai)
751-
cj = math.cos(aj)
752-
sj = math.sin(aj)
753-
ck = math.cos(ak)
754-
sk = math.sin(ak)
744+
ci = torch.cos(ai)
745+
si = torch.sin(ai)
746+
cj = torch.cos(aj)
747+
sj = torch.sin(aj)
748+
ck = torch.cos(ak)
749+
sk = torch.sin(ak)
755750
cc = ci * ck
756751
cs = ci * sk
757752
sc = si * ck
758753
ss = si * sk
759754

760-
q = numpy.empty((4,))
755+
q = torch.zeros([*rpy.shape[:-1], 4]).to(rpy)
761756
if repetition:
762-
q[0] = cj * (cc - ss)
763-
q[i] = cj * (cs + sc)
764-
q[j] = sj * (cc + ss)
765-
q[k] = sj * (cs - sc)
757+
q[..., 0] = cj * (cc - ss)
758+
q[..., i] = cj * (cs + sc)
759+
q[..., j] = sj * (cc + ss)
760+
q[..., k] = sj * (cs - sc)
766761
else:
767-
q[0] = cj * cc + sj * ss
768-
q[i] = cj * sc - sj * cs
769-
q[j] = cj * ss + sj * cc
770-
q[k] = cj * cs - sj * sc
762+
q[..., 0] = cj * cc + sj * ss
763+
q[..., i] = cj * sc - sj * cs
764+
q[..., j] = cj * ss + sj * cc
765+
q[..., k] = cj * cs - sj * sc
771766
if parity:
772-
q[j] *= -1.0
767+
q[..., j] *= -1.0
773768

774769
return q

src/pytorch_kinematics/urdf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ def _convert_transform(origin):
1414
if origin is None:
1515
return tf.Transform3d()
1616
else:
17-
return tf.Transform3d(rot=torch.tensor(tf.quaternion_from_euler(*origin.rpy, "sxyz"), dtype=torch.float32),
18-
pos=origin.xyz)
17+
rpy = torch.tensor(origin.rpy, dtype=torch.float32)
18+
return tf.Transform3d(rot=tf.quaternion_from_euler(rpy, "sxyz"), pos=origin.xyz)
1919

2020

2121
def _convert_visual(visual):

tests/test_kinematics.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66

77
import pytorch_kinematics as pk
8+
from pytorch_kinematics.transforms.math import quaternion_close
89

910
TEST_DIR = os.path.dirname(__file__)
1011

@@ -16,11 +17,6 @@ def quat_pos_from_transform3d(tg):
1617
return pos, rot
1718

1819

19-
def quaternion_equality(a, b, rtol=1e-5):
20-
# negative of a quaternion is the same rotation
21-
return torch.allclose(a, b, rtol=rtol) or torch.allclose(a, -b, rtol=rtol)
22-
23-
2420
# test more complex robot and the MJCF parser
2521
def test_fk_mjcf():
2622
chain = pk.build_chain_from_mjcf(open(os.path.join(TEST_DIR, "ant.xml")).read())
@@ -33,11 +29,11 @@ def test_fk_mjcf():
3329
ret = chain.forward_kinematics(th)
3430
tg = ret['aux_1']
3531
pos, rot = quat_pos_from_transform3d(tg)
36-
assert quaternion_equality(rot, torch.tensor([0.87758256, 0., 0., 0.47942554], dtype=torch.float64))
32+
assert quaternion_close(rot, torch.tensor([0.87758256, 0., 0., 0.47942554], dtype=torch.float64))
3733
assert torch.allclose(pos, torch.tensor([0.2, 0.2, 0.75], dtype=torch.float64))
3834
tg = ret['front_left_foot']
3935
pos, rot = quat_pos_from_transform3d(tg)
40-
assert quaternion_equality(rot, torch.tensor([0.77015115, -0.4600326, 0.13497724, 0.42073549], dtype=torch.float64))
36+
assert quaternion_close(rot, torch.tensor([0.77015115, -0.4600326, 0.13497724, 0.42073549], dtype=torch.float64))
4137
assert torch.allclose(pos, torch.tensor([0.13976626, 0.47635466, 0.75], dtype=torch.float64))
4238
print(ret)
4339

@@ -47,7 +43,7 @@ def test_fk_serial_mjcf():
4743
chain = chain.to(dtype=torch.float64)
4844
tg = chain.forward_kinematics([1.0, 1.0])
4945
pos, rot = quat_pos_from_transform3d(tg)
50-
assert quaternion_equality(rot, torch.tensor([0.77015115, -0.4600326, 0.13497724, 0.42073549], dtype=torch.float64))
46+
assert quaternion_close(rot, torch.tensor([0.77015115, -0.4600326, 0.13497724, 0.42073549], dtype=torch.float64))
5147
assert torch.allclose(pos, torch.tensor([0.13976626, 0.47635466, 0.75], dtype=torch.float64))
5248

5349

@@ -72,7 +68,7 @@ def test_fkik():
7268
tg = chain.forward_kinematics(th1)
7369
pos, rot = quat_pos_from_transform3d(tg)
7470
assert torch.allclose(pos, torch.tensor([[1.91081784, 0.41280851, 0.0000]]))
75-
assert quaternion_equality(rot, torch.tensor([[0.95521418, 0.0000, 0.0000, 0.2959153]]))
71+
assert quaternion_close(rot, torch.tensor([[0.95521418, 0.0000, 0.0000, 0.2959153]]))
7672
N = 20
7773
th_batch = torch.rand(N, 2)
7874
tg_batch = chain.forward_kinematics(th_batch)
@@ -98,22 +94,20 @@ def test_urdf():
9894
ret = chain.forward_kinematics(th)
9995
tg = ret['lbr_iiwa_link_7']
10096
pos, rot = quat_pos_from_transform3d(tg)
101-
assert quaternion_equality(rot, torch.tensor([7.07106781e-01, 0, -7.07106781e-01, 0], dtype=torch.float64))
102-
assert torch.allclose(pos, torch.tensor([-6.60827561e-01, 0, 3.74142136e-01], dtype=torch.float64))
97+
assert quaternion_close(rot, torch.tensor([7.07106781e-01, 0, -7.07106781e-01, 0], dtype=torch.float64))
98+
assert torch.allclose(pos, torch.tensor([-6.60827561e-01, 0, 3.74142136e-01], dtype=torch.float64), atol=1e-6)
10399

104100

105101
def test_urdf_serial():
106102
chain = pk.build_serial_chain_from_urdf(open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")).read(), "lbr_iiwa_link_7")
107103
chain.to(dtype=torch.float64)
108-
print(chain)
109-
print(chain.get_joint_parameter_names())
110104
th = [0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0]
111105

112106
ret = chain.forward_kinematics(th, end_only=False)
113107
tg = ret['lbr_iiwa_link_7']
114108
pos, rot = quat_pos_from_transform3d(tg)
115-
assert quaternion_equality(rot, torch.tensor([7.07106781e-01, 0, -7.07106781e-01, 0], dtype=torch.float64))
116-
assert torch.allclose(pos, torch.tensor([-6.60827561e-01, 0, 3.74142136e-01], dtype=torch.float64))
109+
assert quaternion_close(rot, torch.tensor([7.07106781e-01, 0, -7.07106781e-01, 0], dtype=torch.float64))
110+
assert torch.allclose(pos, torch.tensor([-6.60827561e-01, 0, 3.74142136e-01], dtype=torch.float64), atol=1e-6)
117111

118112
N = 1000
119113
d = "cuda" if torch.cuda.is_available() else "cpu"
@@ -162,7 +156,7 @@ def test_fk_simple_arm():
162156
})
163157
tg = ret['arm_wrist_roll']
164158
pos, rot = quat_pos_from_transform3d(tg)
165-
assert quaternion_equality(rot, torch.tensor([0.70710678, 0., 0., 0.70710678], dtype=torch.float64))
159+
assert quaternion_close(rot, torch.tensor([0.70710678, 0., 0., 0.70710678], dtype=torch.float64))
166160
assert torch.allclose(pos, torch.tensor([1.05, 0.55, 0.5], dtype=torch.float64))
167161

168162
N = 100
@@ -176,7 +170,7 @@ def test_sdf_serial_chain():
176170
chain = chain.to(dtype=torch.float64)
177171
tg = chain.forward_kinematics([0., math.pi / 2.0, -0.5, 0.])
178172
pos, rot = quat_pos_from_transform3d(tg)
179-
assert quaternion_equality(rot, torch.tensor([0.70710678, 0., 0., 0.70710678], dtype=torch.float64))
173+
assert quaternion_close(rot, torch.tensor([0.70710678, 0., 0., 0.70710678], dtype=torch.float64))
180174
assert torch.allclose(pos, torch.tensor([1.05, 0.55, 0.5], dtype=torch.float64))
181175

182176

@@ -201,7 +195,7 @@ def test_cuda():
201195
})
202196
tg = ret['arm_wrist_roll']
203197
pos, rot = quat_pos_from_transform3d(tg)
204-
assert quaternion_equality(rot, torch.tensor([0.70710678, 0., 0., 0.70710678], dtype=dtype, device=d))
198+
assert quaternion_close(rot, torch.tensor([0.70710678, 0., 0., 0.70710678], dtype=dtype, device=d))
205199
assert torch.allclose(pos, torch.tensor([1.05, 0.55, 0.5], dtype=dtype, device=d))
206200

207201
data = '<robot name="test_robot">' \
@@ -256,7 +250,7 @@ def test_fk_val():
256250
tg = ret['drive45']
257251
pos, rot = quat_pos_from_transform3d(tg)
258252
torch.set_printoptions(precision=6, sci_mode=False)
259-
assert quaternion_equality(rot, torch.tensor([0.5, 0.5, -0.5, 0.5], dtype=torch.float64), rtol=1e-4)
253+
assert quaternion_close(rot, torch.tensor([0.5, 0.5, -0.5, 0.5], dtype=torch.float64))
260254
assert torch.allclose(pos, torch.tensor([-0.225692, 0.259045, 0.262139], dtype=torch.float64))
261255

262256

tests/test_rotation_conversions.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import torch
55

6+
from pytorch_kinematics.transforms.math import quaternion_close
67
from pytorch_kinematics.transforms.rotation_conversions import axis_and_angle_to_matrix_33, axis_angle_to_matrix, \
78
pos_rot_to_matrix, matrix_to_pos_rot, random_rotations, quaternion_from_euler
89

@@ -24,18 +25,27 @@ def test_axis_angle_to_matrix_perf():
2425

2526

2627
def test_quaternion_from_euler():
27-
q = quaternion_from_euler(0, 0, 0)
28-
np.testing.assert_allclose(q, np.array([1., 0, 0, 0]))
28+
q = quaternion_from_euler(torch.tensor([0., 0, 0]))
29+
assert quaternion_close(q, torch.tensor([1., 0, 0, 0]))
2930
root2_over_2 = np.sqrt(2) / 2
3031

31-
q = quaternion_from_euler(0, 0, np.pi / 2)
32-
np.testing.assert_allclose(q, np.array([root2_over_2, 0, 0, root2_over_2]))
32+
q = quaternion_from_euler(torch.tensor([0, 0, np.pi / 2]))
33+
assert quaternion_close(q, torch.tensor([root2_over_2, 0, 0, root2_over_2], dtype=q.dtype))
3334

34-
q = quaternion_from_euler(-np.pi / 2, 0, 0)
35-
np.testing.assert_allclose(q, np.array([root2_over_2, -root2_over_2, 0, 0]))
35+
q = quaternion_from_euler(torch.tensor([-np.pi / 2, 0, 0]))
36+
assert quaternion_close(q, torch.tensor([root2_over_2, -root2_over_2, 0, 0], dtype=q.dtype))
3637

37-
q = quaternion_from_euler(0, np.pi / 2, 0)
38-
np.testing.assert_allclose(q, np.array([root2_over_2, 0, root2_over_2, 0]))
38+
q = quaternion_from_euler(torch.tensor([0, np.pi / 2, 0]))
39+
assert quaternion_close(q, torch.tensor([root2_over_2, 0, root2_over_2, 0], dtype=q.dtype))
40+
41+
# Test batched
42+
b = 32
43+
rpy = torch.tensor([0, np.pi / 2, 0])
44+
rpy_batch = torch.tile(rpy[None], (b, 1))
45+
q_batch = quaternion_from_euler(rpy_batch)
46+
q_expected = torch.tensor([root2_over_2, 0, root2_over_2, 0], dtype=q.dtype)
47+
q_expected_batch = torch.tile(q_expected[None], (b, 1))
48+
assert quaternion_close(q_batch, q_expected_batch)
3949

4050

4151
def test_pos_rot_conversion():

0 commit comments

Comments
 (0)