|
29 | 29 | lazy_import_casadi, |
30 | 30 | # lazy_import_pinocchio_casadi, |
31 | 31 | ) |
| 32 | +from embodichain.lab.sim.utility.solver_utils import ( |
| 33 | + build_reduced_pinocchio_robot, |
| 34 | + validate_iteration_params, |
| 35 | + compute_pinocchio_fk, |
| 36 | +) |
32 | 37 |
|
33 | 38 |
|
34 | 39 | if TYPE_CHECKING: |
@@ -72,12 +77,7 @@ def init_solver(self, **kwargs) -> "PinocchioSolver": |
72 | 77 | solver = PinocchioSolver(cfg=self, **kwargs) |
73 | 78 |
|
74 | 79 | # Set the Tool Center Point (TCP) for the solver |
75 | | - if isinstance(self.tcp, torch.Tensor): |
76 | | - tcp = self.tcp.cpu().numpy() |
77 | | - else: |
78 | | - tcp = self.tcp |
79 | | - |
80 | | - solver.set_tcp(tcp) |
| 80 | + solver.set_tcp(self._get_tcp_as_numpy()) |
81 | 81 |
|
82 | 82 | return solver |
83 | 83 |
|
@@ -121,7 +121,7 @@ def __init__(self, cfg: PinocchioSolverCfg, **kwargs): |
121 | 121 | ) # Degrees of freedom of robot joints |
122 | 122 |
|
123 | 123 | # Build reduced robot model (only relevant joints unlocked) |
124 | | - self.robot = self._get_reduce_robot() |
| 124 | + self.robot = build_reduced_pinocchio_robot(self.entire_robot, self.joint_names) |
125 | 125 | self.joint_names = self.robot.model.names.tolist()[ |
126 | 126 | 1: |
127 | 127 | ] # Exclude 'universe' joint |
@@ -221,26 +221,6 @@ def __init__(self, cfg: PinocchioSolverCfg, **kwargs): |
221 | 221 | self.root_base_xpos[:3, :3] = root_base_pose.rotation |
222 | 222 | self.root_base_xpos[:3, 3] = root_base_pose.translation.T |
223 | 223 |
|
224 | | - def _get_reduce_robot(self) -> "pin.RobotWrapper": |
225 | | - """Build a reduced robot model by locking all joints except those in self.joint_names. |
226 | | -
|
227 | | - Returns: |
228 | | - pin.RobotWrapper: The reduced robot model with specified joints unlocked. |
229 | | - """ |
230 | | - all_joint_names = self.entire_robot.model.names.tolist() |
231 | | - |
232 | | - # Lock all joints except those in self.joint_names and 'universe' |
233 | | - fixed_joint_names = [ |
234 | | - name |
235 | | - for name in all_joint_names |
236 | | - if name not in self.joint_names and name != "universe" |
237 | | - ] |
238 | | - |
239 | | - reduced_robot = self.entire_robot.buildReducedRobot( |
240 | | - list_of_joints_to_lock=fixed_joint_names |
241 | | - ) |
242 | | - return reduced_robot |
243 | | - |
244 | 224 | def set_tcp(self, tcp: np.ndarray): |
245 | 225 | self.tcp = tcp |
246 | 226 |
|
@@ -290,25 +270,10 @@ def set_iteration_params( |
290 | 270 | Returns: |
291 | 271 | bool: True if all parameters are valid and set, False otherwise. |
292 | 272 | """ |
293 | | - # TODO: Check which parameters are no longer needed. |
294 | 273 | # Validate parameters |
295 | | - if pos_eps <= 0: |
296 | | - logger.log_warning("Pos epsilon must be positive.") |
297 | | - return False |
298 | | - if rot_eps <= 0: |
299 | | - logger.log_warning("Rot epsilon must be positive.") |
300 | | - return False |
301 | | - if max_iterations <= 0: |
302 | | - logger.log_warning("Max iterations must be positive.") |
303 | | - return False |
304 | | - if dt <= 0: |
305 | | - logger.log_warning("Time step must be positive.") |
306 | | - return False |
307 | | - if damp < 0: |
308 | | - logger.log_warning("Damping factor must be non-negative.") |
309 | | - return False |
310 | | - if num_samples <= 0: |
311 | | - logger.log_warning("Number of samples must be positive.") |
| 274 | + if not validate_iteration_params( |
| 275 | + pos_eps, rot_eps, max_iterations, dt, damp, num_samples |
| 276 | + ): |
312 | 277 | return False |
313 | 278 |
|
314 | 279 | # Set parameters if all are valid |
@@ -620,25 +585,6 @@ def _get_fk( |
620 | 585 | Returns: |
621 | 586 | np.ndarray: The resulting end-effector pose as a (4, 4) homogeneous transformation matrix. |
622 | 587 | """ |
623 | | - if isinstance(qpos, torch.Tensor): |
624 | | - qpos_np = qpos.detach().cpu().numpy() |
625 | | - else: |
626 | | - qpos_np = np.array(qpos) |
627 | | - |
628 | | - qpos_np = np.squeeze(qpos_np) |
629 | | - if qpos_np.ndim != 1: |
630 | | - raise ValueError(f"qpos shape must be (nq,), but got {qpos_np.shape}") |
631 | | - |
632 | | - self.pin.forwardKinematics(self.robot.model, self.robot.data, qpos_np) |
633 | | - |
634 | | - # Retrieve the pose of the specified link |
635 | | - frame_index = self.robot.model.getFrameId(self.end_link_name) |
636 | | - joint_index = self.robot.model.frames[frame_index].parent |
637 | | - xpos_se3 = self.robot.data.oMi.tolist()[joint_index] |
638 | | - |
639 | | - xpos = np.eye(4) |
640 | | - xpos[:3, :3] = xpos_se3.rotation |
641 | | - xpos[:3, 3] = xpos_se3.translation.T |
642 | | - |
643 | | - result = np.dot(xpos, self.tcp_xpos) |
644 | | - return result |
| 588 | + return compute_pinocchio_fk( |
| 589 | + self.pin, self.robot, qpos, self.end_link_name, self.tcp_xpos |
| 590 | + ) |
0 commit comments