Skip to content

Commit d5a9788

Browse files
committed
Compare orientations in quaternion space to avoid discontinuities
Greatly improve convergence rate and total success chance
1 parent 67b111e commit d5a9788

File tree

1 file changed

+19
-11
lines changed
  • src/pytorch_kinematics

1 file changed

+19
-11
lines changed

src/pytorch_kinematics/ik.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def config_sampling_method(num_configs):
8181

8282

8383
class LineSearch:
84-
def do_line_search(self, chain, q, dq, target_pos, target_rot_rpy, initial_dx, problem_remaining=None):
84+
def do_line_search(self, chain, q, dq, target_pos, target_wxyz, initial_dx, problem_remaining=None):
8585
raise NotImplementedError()
8686

8787

@@ -92,7 +92,7 @@ def __init__(self, max_lr=1.0, decrease_factor=0.5, max_iterations=5, sufficient
9292
self.max_iterations = max_iterations
9393
self.sufficient_decrease = sufficient_decrease
9494

95-
def do_line_search(self, chain, q, dq, target_pos, target_rot_rpy, initial_dx, problem_remaining=None):
95+
def do_line_search(self, chain, q, dq, target_pos, target_wxyz, initial_dx, problem_remaining=None):
9696
N = target_pos.shape[0]
9797
NM = q.shape[0]
9898
M = NM // N
@@ -112,7 +112,7 @@ def do_line_search(self, chain, q, dq, target_pos, target_rot_rpy, initial_dx, p
112112
# evaluate the error
113113
m = chain.forward_kinematics(q_new).get_matrix()
114114
m = m.view(-1, M, 4, 4)
115-
dx, pos_diff, rot_diff = delta_pose(m, target_pos, target_rot_rpy)
115+
dx, pos_diff, rot_diff = delta_pose(m, target_pos, target_wxyz)
116116
err_new = dx.squeeze().norm(dim=-1)
117117
# check if it's better
118118
improvement = err - err_new
@@ -228,20 +228,28 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
228228
raise NotImplementedError()
229229

230230

231-
def delta_pose(m: torch.tensor, target_pos, target_rot_rpy):
231+
def delta_pose(m: torch.tensor, target_pos, target_wxyz):
232232
"""
233233
Determine the error in position and rotation between the given poses and the target poses
234234
235235
:param m: (N x M x 4 x 4) tensor of homogenous transforms
236236
:param target_pos:
237-
:param target_rot_rpy:
237+
:param target_wxyz: target orientation represented in unit quaternion
238238
:return: (N*M, 6, 1) tensor of delta pose (dx, dy, dz, droll, dpitch, dyaw)
239239
"""
240240
pos_diff = target_pos.unsqueeze(1) - m[:, :, :3, 3]
241241
pos_diff = pos_diff.view(-1, 3, 1)
242-
rot_diff = target_rot_rpy.unsqueeze(1) - rotation_conversions.matrix_to_euler_angles(m[:, :, :3, :3],
243-
"XYZ")
244-
rot_diff = rot_diff.view(-1, 3, 1)
242+
cur_wxyz = rotation_conversions.matrix_to_quaternion(m[:, :, :3, :3])
243+
244+
# quaternion that rotates from the current orientation to the desired orientation
245+
# inverse for unit quaternion is the conjugate
246+
diff_wxyz = rotation_conversions.quaternion_multiply(target_wxyz.unsqueeze(1),
247+
rotation_conversions.quaternion_invert(cur_wxyz))
248+
# angular velocity vector needed to correct the orientation
249+
# if time is considered, should divide by \delta t, but doing it iteratively we can choose delta t to be 1
250+
diff_axis_angle = rotation_conversions.quaternion_to_axis_angle(diff_wxyz)
251+
252+
rot_diff = diff_axis_angle.view(-1, 3, 1)
245253

246254
dx = torch.cat((pos_diff, rot_diff), dim=1)
247255
return dx, pos_diff, rot_diff
@@ -273,7 +281,7 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
273281
target_pos = target[:, :3, 3]
274282
# jacobian gives angular rotation about x,y,z axis of the base frame
275283
# convert target rot to desired rotation about x,y,z
276-
target_rot_rpy = rotation_conversions.matrix_to_euler_angles(target[:, :3, :3], "XYZ")
284+
target_wxyz = rotation_conversions.matrix_to_quaternion(target[:, :3, :3])
277285

278286
sol = IKSolution(self.dof, M, self.num_retries, self.pos_tolerance, self.rot_tolerance, device=self.device)
279287

@@ -306,7 +314,7 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
306314
J, m = self.chain.jacobian(q, ret_eef_pose=True)
307315
# unflatten to broadcast with goal
308316
m = m.view(-1, self.num_retries, 4, 4)
309-
dx, pos_diff, rot_diff = delta_pose(m, target_pos, target_rot_rpy)
317+
dx, pos_diff, rot_diff = delta_pose(m, target_pos, target_wxyz)
310318

311319
# damped least squares method
312320
# lambda^2*I (lambda^2 is regularization)
@@ -321,7 +329,7 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
321329
else:
322330
with torch.no_grad():
323331
if self.line_search is not None:
324-
lr, improvement = self.line_search.do_line_search(self.chain, q, dq, target_pos, target_rot_rpy,
332+
lr, improvement = self.line_search.do_line_search(self.chain, q, dq, target_pos, target_wxyz,
325333
dx, problem_remaining=sol.remaining)
326334
lr = lr.unsqueeze(1)
327335
else:

0 commit comments

Comments
 (0)