Skip to content

Commit 67b111e

Browse files
committed
Add SVD based IK solver
1 parent 90913be commit 67b111e

File tree

1 file changed

+40
-6
lines changed
  • src/pytorch_kinematics

1 file changed

+40
-6
lines changed

src/pytorch_kinematics/ik.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,19 @@ def apply_mask(mask, *args):
252252

253253

254254
class PseudoInverseIK(InverseKinematics):
255+
def compute_dq(self, J, dx):
256+
# lambda^2*I (lambda^2 is regularization)
257+
reg = self.regularlization * torch.eye(6, device=self.device, dtype=self.dtype)
258+
259+
# JJ^T + lambda^2*I (lambda^2 is regularization)
260+
tmpA = J @ J.transpose(1, 2) + reg
261+
# (JJ^T + lambda^2I) A = dx
262+
# A = (JJ^T + lambda^2I)^-1 dx
263+
A = torch.linalg.solve(tmpA, dx)
264+
# dq = J^T (JJ^T + lambda^2I)^-1 dx
265+
dq = J.transpose(1, 2) @ A
266+
return dq
267+
255268
def solve(self, target_poses: Transform3d) -> IKSolution:
256269
target = target_poses.get_matrix()
257270

@@ -296,12 +309,8 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
296309
dx, pos_diff, rot_diff = delta_pose(m, target_pos, target_rot_rpy)
297310

298311
# damped least squares method
299-
# JJ^T + lambda^2*I (lambda^2 is regularization)
300-
tmpA = J @ J.transpose(1, 2) + self.regularlization * torch.eye(6, device=self.device, dtype=self.dtype)
301-
# (JJ^T + lambda^2I) A = dx
302-
A = torch.linalg.solve(tmpA, dx)
303-
# dq = J^T (JJ^T + lambda^2I)^-1 dx
304-
dq = J.transpose(1, 2) @ A
312+
# lambda^2*I (lambda^2 is regularization)
313+
dq = self.compute_dq(J, dx)
305314
dq = dq.squeeze(2)
306315

307316
improvement = None
@@ -381,3 +390,28 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
381390
if i == self.max_iterations - 1:
382391
sol.update(q, self.err_all, use_keep_mask=False)
383392
return sol
393+
394+
395+
class PseudoInverseIKWithSVD(PseudoInverseIK):
396+
# generally slower, but allows for selective damping if needed
397+
def compute_dq(self, J, dx):
398+
# reg = self.regularlization * torch.eye(6, device=self.device, dtype=self.dtype)
399+
U, D, Vh = torch.linalg.svd(J)
400+
m = D.shape[1]
401+
402+
# tmpA = U @ (D @ D.transpose(1, 2) + reg) @ U.transpose(1, 2)
403+
# singular_val = torch.diagonal(D)
404+
405+
denom = D ** 2 + self.regularlization
406+
prod = D / denom
407+
# J^T (JJ^T + lambda^2I)^-1 = V @ (D @ D^T + lambda^2I)^-1 @ U^T = sum_i (d_i / (d_i^2 + lambda^2) v_i @ u_i^T)
408+
# should be equivalent to damped least squares
409+
inverted = torch.diag_embed(prod)
410+
411+
# drop columns from V
412+
Vh = Vh[:, :m, :]
413+
total = Vh.transpose(1, 2) @ inverted @ U.transpose(1, 2)
414+
415+
# dq = J^T (JJ^T + lambda^2I)^-1 dx
416+
dq = total @ dx
417+
return dq

0 commit comments

Comments
 (0)