@@ -148,7 +148,7 @@ def __init__(self, serial_chain: SerialChain,
148
148
:param pos_tolerance: position tolerance in meters
149
149
:param rot_tolerance: rotation tolerance in radians
150
150
:param retry_configs: (M, DOF) tensor of initial configs to try for each problem; leave as None to sample
151
- :param num_retries: number, M, of random initial configs to try for that problem
151
+ :param num_retries: number, M, of random initial configs to try for that problem; implemented with batching
152
152
:param joint_limits: (DOF, 2) tensor of joint limits (min, max) for each joint in radians
153
153
:param config_sampling_method: either "uniform" or "gaussian" or a function that takes in the number of configs
154
154
:param max_iterations: maximum number of iterations to run
@@ -295,8 +295,12 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
295
295
m = m .view (- 1 , self .num_retries , 4 , 4 )
296
296
dx , pos_diff , rot_diff = delta_pose (m , target_pos , target_rot_rpy )
297
297
298
+ # damped least squares method
299
+ # JJ^T + lambda^2*I (lambda^2 is regularization)
298
300
tmpA = J @ J .transpose (1 , 2 ) + self .regularlization * torch .eye (6 , device = self .device , dtype = self .dtype )
301
+ # (JJ^T + lambda^2I) A = dx
299
302
A = torch .linalg .solve (tmpA , dx )
303
+ # dq = J^T (JJ^T + lambda^2I)^-1 dx
300
304
dq = J .transpose (1 , 2 ) @ A
301
305
dq = dq .squeeze (2 )
302
306
0 commit comments