5
5
import torch
6
6
7
7
import pytorch_kinematics as pk
8
+ from pytorch_kinematics .transforms .math import quaternion_close
8
9
9
10
TEST_DIR = os .path .dirname (__file__ )
10
11
@@ -16,11 +17,6 @@ def quat_pos_from_transform3d(tg):
16
17
return pos , rot
17
18
18
19
19
- def quaternion_equality (a , b , rtol = 1e-5 ):
20
- # negative of a quaternion is the same rotation
21
- return torch .allclose (a , b , rtol = rtol ) or torch .allclose (a , - b , rtol = rtol )
22
-
23
-
24
20
# test more complex robot and the MJCF parser
25
21
def test_fk_mjcf ():
26
22
chain = pk .build_chain_from_mjcf (open (os .path .join (TEST_DIR , "ant.xml" )).read ())
@@ -33,11 +29,11 @@ def test_fk_mjcf():
33
29
ret = chain .forward_kinematics (th )
34
30
tg = ret ['aux_1' ]
35
31
pos , rot = quat_pos_from_transform3d (tg )
36
- assert quaternion_equality (rot , torch .tensor ([0.87758256 , 0. , 0. , 0.47942554 ], dtype = torch .float64 ))
32
+ assert quaternion_close (rot , torch .tensor ([0.87758256 , 0. , 0. , 0.47942554 ], dtype = torch .float64 ))
37
33
assert torch .allclose (pos , torch .tensor ([0.2 , 0.2 , 0.75 ], dtype = torch .float64 ))
38
34
tg = ret ['front_left_foot' ]
39
35
pos , rot = quat_pos_from_transform3d (tg )
40
- assert quaternion_equality (rot , torch .tensor ([0.77015115 , - 0.4600326 , 0.13497724 , 0.42073549 ], dtype = torch .float64 ))
36
+ assert quaternion_close (rot , torch .tensor ([0.77015115 , - 0.4600326 , 0.13497724 , 0.42073549 ], dtype = torch .float64 ))
41
37
assert torch .allclose (pos , torch .tensor ([0.13976626 , 0.47635466 , 0.75 ], dtype = torch .float64 ))
42
38
print (ret )
43
39
@@ -47,7 +43,7 @@ def test_fk_serial_mjcf():
47
43
chain = chain .to (dtype = torch .float64 )
48
44
tg = chain .forward_kinematics ([1.0 , 1.0 ])
49
45
pos , rot = quat_pos_from_transform3d (tg )
50
- assert quaternion_equality (rot , torch .tensor ([0.77015115 , - 0.4600326 , 0.13497724 , 0.42073549 ], dtype = torch .float64 ))
46
+ assert quaternion_close (rot , torch .tensor ([0.77015115 , - 0.4600326 , 0.13497724 , 0.42073549 ], dtype = torch .float64 ))
51
47
assert torch .allclose (pos , torch .tensor ([0.13976626 , 0.47635466 , 0.75 ], dtype = torch .float64 ))
52
48
53
49
@@ -72,7 +68,7 @@ def test_fkik():
72
68
tg = chain .forward_kinematics (th1 )
73
69
pos , rot = quat_pos_from_transform3d (tg )
74
70
assert torch .allclose (pos , torch .tensor ([[1.91081784 , 0.41280851 , 0.0000 ]]))
75
- assert quaternion_equality (rot , torch .tensor ([[0.95521418 , 0.0000 , 0.0000 , 0.2959153 ]]))
71
+ assert quaternion_close (rot , torch .tensor ([[0.95521418 , 0.0000 , 0.0000 , 0.2959153 ]]))
76
72
N = 20
77
73
th_batch = torch .rand (N , 2 )
78
74
tg_batch = chain .forward_kinematics (th_batch )
@@ -98,22 +94,20 @@ def test_urdf():
98
94
ret = chain .forward_kinematics (th )
99
95
tg = ret ['lbr_iiwa_link_7' ]
100
96
pos , rot = quat_pos_from_transform3d (tg )
101
- assert quaternion_equality (rot , torch .tensor ([7.07106781e-01 , 0 , - 7.07106781e-01 , 0 ], dtype = torch .float64 ))
102
- assert torch .allclose (pos , torch .tensor ([- 6.60827561e-01 , 0 , 3.74142136e-01 ], dtype = torch .float64 ))
97
+ assert quaternion_close (rot , torch .tensor ([7.07106781e-01 , 0 , - 7.07106781e-01 , 0 ], dtype = torch .float64 ))
98
+ assert torch .allclose (pos , torch .tensor ([- 6.60827561e-01 , 0 , 3.74142136e-01 ], dtype = torch .float64 ), atol = 1e-6 )
103
99
104
100
105
101
def test_urdf_serial ():
106
102
chain = pk .build_serial_chain_from_urdf (open (os .path .join (TEST_DIR , "kuka_iiwa.urdf" )).read (), "lbr_iiwa_link_7" )
107
103
chain .to (dtype = torch .float64 )
108
- print (chain )
109
- print (chain .get_joint_parameter_names ())
110
104
th = [0.0 , - math .pi / 4.0 , 0.0 , math .pi / 2.0 , 0.0 , math .pi / 4.0 , 0.0 ]
111
105
112
106
ret = chain .forward_kinematics (th , end_only = False )
113
107
tg = ret ['lbr_iiwa_link_7' ]
114
108
pos , rot = quat_pos_from_transform3d (tg )
115
- assert quaternion_equality (rot , torch .tensor ([7.07106781e-01 , 0 , - 7.07106781e-01 , 0 ], dtype = torch .float64 ))
116
- assert torch .allclose (pos , torch .tensor ([- 6.60827561e-01 , 0 , 3.74142136e-01 ], dtype = torch .float64 ))
109
+ assert quaternion_close (rot , torch .tensor ([7.07106781e-01 , 0 , - 7.07106781e-01 , 0 ], dtype = torch .float64 ))
110
+ assert torch .allclose (pos , torch .tensor ([- 6.60827561e-01 , 0 , 3.74142136e-01 ], dtype = torch .float64 ), atol = 1e-6 )
117
111
118
112
N = 1000
119
113
d = "cuda" if torch .cuda .is_available () else "cpu"
@@ -162,7 +156,7 @@ def test_fk_simple_arm():
162
156
})
163
157
tg = ret ['arm_wrist_roll' ]
164
158
pos , rot = quat_pos_from_transform3d (tg )
165
- assert quaternion_equality (rot , torch .tensor ([0.70710678 , 0. , 0. , 0.70710678 ], dtype = torch .float64 ))
159
+ assert quaternion_close (rot , torch .tensor ([0.70710678 , 0. , 0. , 0.70710678 ], dtype = torch .float64 ))
166
160
assert torch .allclose (pos , torch .tensor ([1.05 , 0.55 , 0.5 ], dtype = torch .float64 ))
167
161
168
162
N = 100
@@ -176,7 +170,7 @@ def test_sdf_serial_chain():
176
170
chain = chain .to (dtype = torch .float64 )
177
171
tg = chain .forward_kinematics ([0. , math .pi / 2.0 , - 0.5 , 0. ])
178
172
pos , rot = quat_pos_from_transform3d (tg )
179
- assert quaternion_equality (rot , torch .tensor ([0.70710678 , 0. , 0. , 0.70710678 ], dtype = torch .float64 ))
173
+ assert quaternion_close (rot , torch .tensor ([0.70710678 , 0. , 0. , 0.70710678 ], dtype = torch .float64 ))
180
174
assert torch .allclose (pos , torch .tensor ([1.05 , 0.55 , 0.5 ], dtype = torch .float64 ))
181
175
182
176
@@ -201,7 +195,7 @@ def test_cuda():
201
195
})
202
196
tg = ret ['arm_wrist_roll' ]
203
197
pos , rot = quat_pos_from_transform3d (tg )
204
- assert quaternion_equality (rot , torch .tensor ([0.70710678 , 0. , 0. , 0.70710678 ], dtype = dtype , device = d ))
198
+ assert quaternion_close (rot , torch .tensor ([0.70710678 , 0. , 0. , 0.70710678 ], dtype = dtype , device = d ))
205
199
assert torch .allclose (pos , torch .tensor ([1.05 , 0.55 , 0.5 ], dtype = dtype , device = d ))
206
200
207
201
data = '<robot name="test_robot">' \
@@ -256,7 +250,7 @@ def test_fk_val():
256
250
tg = ret ['drive45' ]
257
251
pos , rot = quat_pos_from_transform3d (tg )
258
252
torch .set_printoptions (precision = 6 , sci_mode = False )
259
- assert quaternion_equality (rot , torch .tensor ([0.5 , 0.5 , - 0.5 , 0.5 ], dtype = torch .float64 ), rtol = 1e-4 )
253
+ assert quaternion_close (rot , torch .tensor ([0.5 , 0.5 , - 0.5 , 0.5 ], dtype = torch .float64 ))
260
254
assert torch .allclose (pos , torch .tensor ([- 0.225692 , 0.259045 , 0.262139 ], dtype = torch .float64 ))
261
255
262
256
0 commit comments