Skip to content

Commit be5a951

Browse files
Copilotyuecideng
andauthored
Refactor duplicated code in solver modules (#20)
Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: yuecideng <[email protected]>
1 parent 3800e83 commit be5a951

File tree

8 files changed

+174
-153
lines changed

8 files changed

+174
-153
lines changed

embodichain/lab/sim/solvers/base_solver.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,19 @@ class SolverCfg:
7676
def init_solver(self, device: torch.device, **kwargs) -> "BaseSolver":
7777
pass
7878

79+
def _get_tcp_as_numpy(self) -> np.ndarray:
80+
"""Convert TCP to numpy array.
81+
82+
This helper method handles the conversion of TCP from torch.Tensor to numpy
83+
if needed. Used by subclass init_solver methods to set TCP on the solver.
84+
85+
Returns:
86+
np.ndarray: The TCP as a numpy array.
87+
"""
88+
if isinstance(self.tcp, torch.Tensor):
89+
return self.tcp.cpu().numpy()
90+
return self.tcp
91+
7992
@classmethod
8093
def from_dict(cls, init_dict: Dict[str, Any]) -> "SolverCfg":
8194
"""Initialize the configuration from a dictionary."""

embodichain/lab/sim/solvers/differential_solver.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,7 @@ def init_solver(
8989
)
9090

9191
# Set the Tool Center Point (TCP) for the solver
92-
if isinstance(self.tcp, torch.Tensor):
93-
tcp = self.tcp.cpu().numpy()
94-
else:
95-
tcp = self.tcp
96-
solver.set_tcp(tcp)
92+
solver.set_tcp(self._get_tcp_as_numpy())
9793

9894
return solver
9995

embodichain/lab/sim/solvers/opw_solver.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,7 @@ def init_solver(
9595
solver = OPWSolver(cfg=self, device=device, **kwargs)
9696

9797
# Set the Tool Center Point (TCP) for the solver
98-
if isinstance(self.tcp, torch.Tensor):
99-
tcp = self.tcp.cpu().numpy()
100-
else:
101-
tcp = self.tcp
102-
solver.set_tcp(tcp)
98+
solver.set_tcp(self._get_tcp_as_numpy())
10399

104100
return solver
105101

embodichain/lab/sim/solvers/pink_solver.py

Lines changed: 9 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
lazy_import_pinocchio,
2525
lazy_import_pink,
2626
)
27+
from embodichain.lab.sim.utility.solver_utils import (
28+
build_reduced_pinocchio_robot,
29+
compute_pinocchio_fk,
30+
)
2731

2832
from embodichain.utils import configclass, logger
2933
from embodichain.lab.sim.solvers import SolverCfg, BaseSolver
@@ -111,12 +115,7 @@ def init_solver(self, **kwargs) -> "PinkSolver":
111115
solver = PinkSolver(cfg=self, **kwargs)
112116

113117
# Set the Tool Center Point (TCP) for the solver
114-
if isinstance(self.tcp, torch.Tensor):
115-
tcp = self.tcp.cpu().numpy()
116-
else:
117-
tcp = self.tcp
118-
119-
solver.set_tcp(tcp)
118+
solver.set_tcp(self._get_tcp_as_numpy())
120119

121120
return solver
122121

@@ -162,7 +161,7 @@ def __init__(self, cfg: PinkSolverCfg, **kwargs):
162161
) # Degrees of freedom of robot joints
163162

164163
# Get reduced robot model
165-
self.robot = self._get_reduce_robot()
164+
self.robot = build_reduced_pinocchio_robot(self.entire_robot, self.joint_names)
166165

167166
# Initialize Pink configuration
168167
self.pink_cfg = self.pink.configuration.Configuration(
@@ -207,26 +206,6 @@ def __init__(self, cfg: PinkSolverCfg, **kwargs):
207206
self.dexsim_to_pink_ordering = None
208207
self.pink_to_dexsim_ordering = None
209208

210-
def _get_reduce_robot(self) -> "pin.RobotWrapper":
211-
"""Build a reduced robot model by locking all joints except those in self.joint_names.
212-
213-
Returns:
214-
pin.RobotWrapper: The reduced robot model with specified joints unlocked.
215-
"""
216-
pink_joint_names = self.entire_robot.model.names.tolist()
217-
218-
# Lock all joints except those in self.joint_names and 'universe'
219-
fixed_joint_names = [
220-
name
221-
for name in pink_joint_names
222-
if name not in self.joint_names and name != "universe"
223-
]
224-
225-
reduced_robot = self.entire_robot.buildReducedRobot(
226-
list_of_joints_to_lock=fixed_joint_names
227-
)
228-
return reduced_robot
229-
230209
def reorder_array(
231210
self, input_array: List[float], reordering_array: List[int]
232211
) -> List[float]:
@@ -393,25 +372,7 @@ def _get_fk(
393372
Returns:
394373
torch.Tensor: The homogeneous transformation matrix (4x4) of the end-effector (after applying TCP).
395374
"""
396-
if isinstance(qpos, torch.Tensor):
397-
qpos_np = qpos.detach().cpu().numpy()
398-
else:
399-
qpos_np = np.array(qpos)
400-
401-
qpos_np = np.squeeze(qpos_np)
402-
if qpos_np.ndim != 1:
403-
raise ValueError(f"qpos shape must be (nq,), but got {qpos_np.shape}")
404-
405-
self.pin.forwardKinematics(self.robot.model, self.robot.data, qpos_np)
406-
407-
# Retrieve the pose of the specified link
408-
frame_index = self.robot.model.getFrameId(self.end_link_name)
409-
joint_index = self.robot.model.frames[frame_index].parent
410-
xpos_se3 = self.robot.data.oMi.tolist()[joint_index]
411-
412-
xpos = np.eye(4)
413-
xpos[:3, :3] = xpos_se3.rotation
414-
xpos[:3, 3] = xpos_se3.translation.T
415-
416-
result = np.dot(xpos, self.tcp_xpos)
375+
result = compute_pinocchio_fk(
376+
self.pin, self.robot, qpos, self.end_link_name, self.tcp_xpos
377+
)
417378
return torch.from_numpy(result)

embodichain/lab/sim/solvers/pinocchio_solver.py

Lines changed: 13 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@
2929
lazy_import_casadi,
3030
# lazy_import_pinocchio_casadi,
3131
)
32+
from embodichain.lab.sim.utility.solver_utils import (
33+
build_reduced_pinocchio_robot,
34+
validate_iteration_params,
35+
compute_pinocchio_fk,
36+
)
3237

3338

3439
if TYPE_CHECKING:
@@ -72,12 +77,7 @@ def init_solver(self, **kwargs) -> "PinocchioSolver":
7277
solver = PinocchioSolver(cfg=self, **kwargs)
7378

7479
# Set the Tool Center Point (TCP) for the solver
75-
if isinstance(self.tcp, torch.Tensor):
76-
tcp = self.tcp.cpu().numpy()
77-
else:
78-
tcp = self.tcp
79-
80-
solver.set_tcp(tcp)
80+
solver.set_tcp(self._get_tcp_as_numpy())
8181

8282
return solver
8383

@@ -121,7 +121,7 @@ def __init__(self, cfg: PinocchioSolverCfg, **kwargs):
121121
) # Degrees of freedom of robot joints
122122

123123
# Build reduced robot model (only relevant joints unlocked)
124-
self.robot = self._get_reduce_robot()
124+
self.robot = build_reduced_pinocchio_robot(self.entire_robot, self.joint_names)
125125
self.joint_names = self.robot.model.names.tolist()[
126126
1:
127127
] # Exclude 'universe' joint
@@ -221,26 +221,6 @@ def __init__(self, cfg: PinocchioSolverCfg, **kwargs):
221221
self.root_base_xpos[:3, :3] = root_base_pose.rotation
222222
self.root_base_xpos[:3, 3] = root_base_pose.translation.T
223223

224-
def _get_reduce_robot(self) -> "pin.RobotWrapper":
225-
"""Build a reduced robot model by locking all joints except those in self.joint_names.
226-
227-
Returns:
228-
pin.RobotWrapper: The reduced robot model with specified joints unlocked.
229-
"""
230-
all_joint_names = self.entire_robot.model.names.tolist()
231-
232-
# Lock all joints except those in self.joint_names and 'universe'
233-
fixed_joint_names = [
234-
name
235-
for name in all_joint_names
236-
if name not in self.joint_names and name != "universe"
237-
]
238-
239-
reduced_robot = self.entire_robot.buildReducedRobot(
240-
list_of_joints_to_lock=fixed_joint_names
241-
)
242-
return reduced_robot
243-
244224
def set_tcp(self, tcp: np.ndarray):
245225
self.tcp = tcp
246226

@@ -290,25 +270,10 @@ def set_iteration_params(
290270
Returns:
291271
bool: True if all parameters are valid and set, False otherwise.
292272
"""
293-
# TODO: Check which parameters are no longer needed.
294273
# Validate parameters
295-
if pos_eps <= 0:
296-
logger.log_warning("Pos epsilon must be positive.")
297-
return False
298-
if rot_eps <= 0:
299-
logger.log_warning("Rot epsilon must be positive.")
300-
return False
301-
if max_iterations <= 0:
302-
logger.log_warning("Max iterations must be positive.")
303-
return False
304-
if dt <= 0:
305-
logger.log_warning("Time step must be positive.")
306-
return False
307-
if damp < 0:
308-
logger.log_warning("Damping factor must be non-negative.")
309-
return False
310-
if num_samples <= 0:
311-
logger.log_warning("Number of samples must be positive.")
274+
if not validate_iteration_params(
275+
pos_eps, rot_eps, max_iterations, dt, damp, num_samples
276+
):
312277
return False
313278

314279
# Set parameters if all are valid
@@ -620,25 +585,6 @@ def _get_fk(
620585
Returns:
621586
np.ndarray: The resulting end-effector pose as a (4, 4) homogeneous transformation matrix.
622587
"""
623-
if isinstance(qpos, torch.Tensor):
624-
qpos_np = qpos.detach().cpu().numpy()
625-
else:
626-
qpos_np = np.array(qpos)
627-
628-
qpos_np = np.squeeze(qpos_np)
629-
if qpos_np.ndim != 1:
630-
raise ValueError(f"qpos shape must be (nq,), but got {qpos_np.shape}")
631-
632-
self.pin.forwardKinematics(self.robot.model, self.robot.data, qpos_np)
633-
634-
# Retrieve the pose of the specified link
635-
frame_index = self.robot.model.getFrameId(self.end_link_name)
636-
joint_index = self.robot.model.frames[frame_index].parent
637-
xpos_se3 = self.robot.data.oMi.tolist()[joint_index]
638-
639-
xpos = np.eye(4)
640-
xpos[:3, :3] = xpos_se3.rotation
641-
xpos[:3, 3] = xpos_se3.translation.T
642-
643-
result = np.dot(xpos, self.tcp_xpos)
644-
return result
588+
return compute_pinocchio_fk(
589+
self.pin, self.robot, qpos, self.end_link_name, self.tcp_xpos
590+
)

embodichain/lab/sim/solvers/pytorch_solver.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from embodichain.utils import configclass, logger
2424
from embodichain.lab.sim.solvers import SolverCfg, BaseSolver
2525
from embodichain.lab.sim.solvers.qpos_seed_sampler import QposSeedSampler
26+
from embodichain.lab.sim.utility.solver_utils import validate_iteration_params
2627

2728
if TYPE_CHECKING:
2829
from typing import Self
@@ -90,11 +91,7 @@ def init_solver(
9091
solver = PytorchSolver(cfg=self, device=device, **kwargs)
9192

9293
# Set the Tool Center Point (TCP) for the solver
93-
if isinstance(self.tcp, torch.Tensor):
94-
tcp = self.tcp.cpu().numpy()
95-
else:
96-
tcp = self.tcp
97-
solver.set_tcp(tcp)
94+
solver.set_tcp(self._get_tcp_as_numpy())
9895

9996
return solver
10097

@@ -227,23 +224,9 @@ def set_iteration_params(
227224
bool: True if all parameters are valid and set, False otherwise.
228225
"""
229226
# Validate parameters
230-
if pos_eps <= 0:
231-
logger.log_warning("Pos epsilon must be positive.")
232-
return False
233-
if rot_eps <= 0:
234-
logger.log_warning("Rot epsilon must be positive.")
235-
return False
236-
if max_iterations <= 0:
237-
logger.log_warning("Max iterations must be positive.")
238-
return False
239-
if dt <= 0:
240-
logger.log_warning("Time step must be positive.")
241-
return False
242-
if damp < 0:
243-
logger.log_warning("Damping factor must be non-negative.")
244-
return False
245-
if num_samples <= 0:
246-
logger.log_warning("Number of samples must be positive.")
227+
if not validate_iteration_params(
228+
pos_eps, rot_eps, max_iterations, dt, damp, num_samples
229+
):
247230
return False
248231

249232
# Set parameters if all are valid

embodichain/lab/sim/solvers/srs_solver.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,7 @@ def init_solver(
9393
solver = SRSSolver(cfg=self, num_envs=num_envs, device=device, **kwargs)
9494

9595
# Set the Tool Center Point (TCP) for the solver
96-
if isinstance(self.tcp, torch.Tensor):
97-
tcp = self.tcp.cpu().numpy()
98-
else:
99-
tcp = self.tcp
100-
solver.set_tcp(tcp)
96+
solver.set_tcp(self._get_tcp_as_numpy())
10197

10298
return solver
10399

0 commit comments

Comments
 (0)