@@ -81,7 +81,7 @@ def config_sampling_method(num_configs):
81
81
82
82
83
83
class LineSearch :
84
- def do_line_search (self , chain , q , dq , target_pos , target_rot_rpy , initial_dx , problem_remaining = None ):
84
+ def do_line_search (self , chain , q , dq , target_pos , target_wxyz , initial_dx , problem_remaining = None ):
85
85
raise NotImplementedError ()
86
86
87
87
@@ -92,7 +92,7 @@ def __init__(self, max_lr=1.0, decrease_factor=0.5, max_iterations=5, sufficient
92
92
self .max_iterations = max_iterations
93
93
self .sufficient_decrease = sufficient_decrease
94
94
95
- def do_line_search (self , chain , q , dq , target_pos , target_rot_rpy , initial_dx , problem_remaining = None ):
95
+ def do_line_search (self , chain , q , dq , target_pos , target_wxyz , initial_dx , problem_remaining = None ):
96
96
N = target_pos .shape [0 ]
97
97
NM = q .shape [0 ]
98
98
M = NM // N
@@ -112,7 +112,7 @@ def do_line_search(self, chain, q, dq, target_pos, target_rot_rpy, initial_dx, p
112
112
# evaluate the error
113
113
m = chain .forward_kinematics (q_new ).get_matrix ()
114
114
m = m .view (- 1 , M , 4 , 4 )
115
- dx , pos_diff , rot_diff = delta_pose (m , target_pos , target_rot_rpy )
115
+ dx , pos_diff , rot_diff = delta_pose (m , target_pos , target_wxyz )
116
116
err_new = dx .squeeze ().norm (dim = - 1 )
117
117
# check if it's better
118
118
improvement = err - err_new
@@ -228,20 +228,28 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
228
228
raise NotImplementedError ()
229
229
230
230
231
- def delta_pose (m : torch .tensor , target_pos , target_rot_rpy ):
231
+ def delta_pose (m : torch .tensor , target_pos , target_wxyz ):
232
232
"""
233
233
Determine the error in position and rotation between the given poses and the target poses
234
234
235
235
:param m: (N x M x 4 x 4) tensor of homogenous transforms
236
236
:param target_pos:
237
- :param target_rot_rpy:
237
+ :param target_wxyz: target orientation represented in unit quaternion
238
238
:return: (N*M, 6, 1) tensor of delta pose (dx, dy, dz, droll, dpitch, dyaw)
239
239
"""
240
240
pos_diff = target_pos .unsqueeze (1 ) - m [:, :, :3 , 3 ]
241
241
pos_diff = pos_diff .view (- 1 , 3 , 1 )
242
- rot_diff = target_rot_rpy .unsqueeze (1 ) - rotation_conversions .matrix_to_euler_angles (m [:, :, :3 , :3 ],
243
- "XYZ" )
244
- rot_diff = rot_diff .view (- 1 , 3 , 1 )
242
+ cur_wxyz = rotation_conversions .matrix_to_quaternion (m [:, :, :3 , :3 ])
243
+
244
+ # quaternion that rotates from the current orientation to the desired orientation
245
+ # inverse for unit quaternion is the conjugate
246
+ diff_wxyz = rotation_conversions .quaternion_multiply (target_wxyz .unsqueeze (1 ),
247
+ rotation_conversions .quaternion_invert (cur_wxyz ))
248
+ # angular velocity vector needed to correct the orientation
249
+ # if time is considered, should divide by \delta t, but doing it iteratively we can choose delta t to be 1
250
+ diff_axis_angle = rotation_conversions .quaternion_to_axis_angle (diff_wxyz )
251
+
252
+ rot_diff = diff_axis_angle .view (- 1 , 3 , 1 )
245
253
246
254
dx = torch .cat ((pos_diff , rot_diff ), dim = 1 )
247
255
return dx , pos_diff , rot_diff
@@ -273,7 +281,7 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
273
281
target_pos = target [:, :3 , 3 ]
274
282
# jacobian gives angular rotation about x,y,z axis of the base frame
275
283
# convert target rot to desired rotation about x,y,z
276
- target_rot_rpy = rotation_conversions .matrix_to_euler_angles (target [:, :3 , :3 ], "XYZ" )
284
+ target_wxyz = rotation_conversions .matrix_to_quaternion (target [:, :3 , :3 ])
277
285
278
286
sol = IKSolution (self .dof , M , self .num_retries , self .pos_tolerance , self .rot_tolerance , device = self .device )
279
287
@@ -306,7 +314,7 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
306
314
J , m = self .chain .jacobian (q , ret_eef_pose = True )
307
315
# unflatten to broadcast with goal
308
316
m = m .view (- 1 , self .num_retries , 4 , 4 )
309
- dx , pos_diff , rot_diff = delta_pose (m , target_pos , target_rot_rpy )
317
+ dx , pos_diff , rot_diff = delta_pose (m , target_pos , target_wxyz )
310
318
311
319
# damped least squares method
312
320
# lambda^2*I (lambda^2 is regularization)
@@ -321,7 +329,7 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
321
329
else :
322
330
with torch .no_grad ():
323
331
if self .line_search is not None :
324
- lr , improvement = self .line_search .do_line_search (self .chain , q , dq , target_pos , target_rot_rpy ,
332
+ lr , improvement = self .line_search .do_line_search (self .chain , q , dq , target_pos , target_wxyz ,
325
333
dx , problem_remaining = sol .remaining )
326
334
lr = lr .unsqueeze (1 )
327
335
else :
0 commit comments