@@ -252,6 +252,19 @@ def apply_mask(mask, *args):
252
252
253
253
254
254
class PseudoInverseIK (InverseKinematics ):
255
+ def compute_dq (self , J , dx ):
256
+ # lambda^2*I (lambda^2 is regularization)
257
+ reg = self .regularlization * torch .eye (6 , device = self .device , dtype = self .dtype )
258
+
259
+ # JJ^T + lambda^2*I (lambda^2 is regularization)
260
+ tmpA = J @ J .transpose (1 , 2 ) + reg
261
+ # (JJ^T + lambda^2I) A = dx
262
+ # A = (JJ^T + lambda^2I)^-1 dx
263
+ A = torch .linalg .solve (tmpA , dx )
264
+ # dq = J^T (JJ^T + lambda^2I)^-1 dx
265
+ dq = J .transpose (1 , 2 ) @ A
266
+ return dq
267
+
255
268
def solve (self , target_poses : Transform3d ) -> IKSolution :
256
269
target = target_poses .get_matrix ()
257
270
@@ -296,12 +309,8 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
296
309
dx , pos_diff , rot_diff = delta_pose (m , target_pos , target_rot_rpy )
297
310
298
311
# damped least squares method
299
- # JJ^T + lambda^2*I (lambda^2 is regularization)
300
- tmpA = J @ J .transpose (1 , 2 ) + self .regularlization * torch .eye (6 , device = self .device , dtype = self .dtype )
301
- # (JJ^T + lambda^2I) A = dx
302
- A = torch .linalg .solve (tmpA , dx )
303
- # dq = J^T (JJ^T + lambda^2I)^-1 dx
304
- dq = J .transpose (1 , 2 ) @ A
312
+ # lambda^2*I (lambda^2 is regularization)
313
+ dq = self .compute_dq (J , dx )
305
314
dq = dq .squeeze (2 )
306
315
307
316
improvement = None
@@ -381,3 +390,28 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
381
390
if i == self .max_iterations - 1 :
382
391
sol .update (q , self .err_all , use_keep_mask = False )
383
392
return sol
393
+
394
+
395
+ class PseudoInverseIKWithSVD (PseudoInverseIK ):
396
+ # generally slower, but allows for selective damping if needed
397
+ def compute_dq (self , J , dx ):
398
+ # reg = self.regularlization * torch.eye(6, device=self.device, dtype=self.dtype)
399
+ U , D , Vh = torch .linalg .svd (J )
400
+ m = D .shape [1 ]
401
+
402
+ # tmpA = U @ (D @ D.transpose(1, 2) + reg) @ U.transpose(1, 2)
403
+ # singular_val = torch.diagonal(D)
404
+
405
+ denom = D ** 2 + self .regularlization
406
+ prod = D / denom
407
+ # J^T (JJ^T + lambda^2I)^-1 = V @ (D @ D^T + lambda^2I)^-1 @ U^T = sum_i (d_i / (d_i^2 + lambda^2) v_i @ u_i^T)
408
+ # should be equivalent to damped least squares
409
+ inverted = torch .diag_embed (prod )
410
+
411
+ # drop columns from V
412
+ Vh = Vh [:, :m , :]
413
+ total = Vh .transpose (1 , 2 ) @ inverted @ U .transpose (1 , 2 )
414
+
415
+ # dq = J^T (JJ^T + lambda^2I)^-1 dx
416
+ dq = total @ dx
417
+ return dq
0 commit comments