Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions embodichain/lab/sim/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,19 @@ class SolverCfg:
def init_solver(self, device: torch.device, **kwargs) -> "BaseSolver":
pass

def _get_tcp_as_numpy(self) -> np.ndarray:
"""Convert TCP to numpy array.

This helper method handles the conversion of TCP from torch.Tensor to numpy
if needed. Used by subclass init_solver methods to set TCP on the solver.

Returns:
np.ndarray: The TCP as a numpy array.
"""
if isinstance(self.tcp, torch.Tensor):
return self.tcp.cpu().numpy()
return self.tcp

@classmethod
def from_dict(cls, init_dict: Dict[str, Any]) -> "SolverCfg":
"""Initialize the configuration from a dictionary."""
Expand Down
6 changes: 1 addition & 5 deletions embodichain/lab/sim/solvers/differential_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,7 @@ def init_solver(
)

# Set the Tool Center Point (TCP) for the solver
if isinstance(self.tcp, torch.Tensor):
tcp = self.tcp.cpu().numpy()
else:
tcp = self.tcp
solver.set_tcp(tcp)
solver.set_tcp(self._get_tcp_as_numpy())

return solver

Expand Down
6 changes: 1 addition & 5 deletions embodichain/lab/sim/solvers/opw_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,7 @@ def init_solver(
solver = OPWSolver(cfg=self, device=device, **kwargs)

# Set the Tool Center Point (TCP) for the solver
if isinstance(self.tcp, torch.Tensor):
tcp = self.tcp.cpu().numpy()
else:
tcp = self.tcp
solver.set_tcp(tcp)
solver.set_tcp(self._get_tcp_as_numpy())

return solver

Expand Down
57 changes: 9 additions & 48 deletions embodichain/lab/sim/solvers/pink_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
lazy_import_pinocchio,
lazy_import_pink,
)
from embodichain.lab.sim.utility.solver_utils import (
build_reduced_pinocchio_robot,
compute_pinocchio_fk,
)

from embodichain.utils import configclass, logger
from embodichain.lab.sim.solvers import SolverCfg, BaseSolver
Expand Down Expand Up @@ -111,12 +115,7 @@ def init_solver(self, **kwargs) -> "PinkSolver":
solver = PinkSolver(cfg=self, **kwargs)

# Set the Tool Center Point (TCP) for the solver
if isinstance(self.tcp, torch.Tensor):
tcp = self.tcp.cpu().numpy()
else:
tcp = self.tcp

solver.set_tcp(tcp)
solver.set_tcp(self._get_tcp_as_numpy())

return solver

Expand Down Expand Up @@ -162,7 +161,7 @@ def __init__(self, cfg: PinkSolverCfg, **kwargs):
) # Degrees of freedom of robot joints

# Get reduced robot model
self.robot = self._get_reduce_robot()
self.robot = build_reduced_pinocchio_robot(self.entire_robot, self.joint_names)

# Initialize Pink configuration
self.pink_cfg = self.pink.configuration.Configuration(
Expand Down Expand Up @@ -207,26 +206,6 @@ def __init__(self, cfg: PinkSolverCfg, **kwargs):
self.dexsim_to_pink_ordering = None
self.pink_to_dexsim_ordering = None

def _get_reduce_robot(self) -> "pin.RobotWrapper":
"""Build a reduced robot model by locking all joints except those in self.joint_names.

Returns:
pin.RobotWrapper: The reduced robot model with specified joints unlocked.
"""
pink_joint_names = self.entire_robot.model.names.tolist()

# Lock all joints except those in self.joint_names and 'universe'
fixed_joint_names = [
name
for name in pink_joint_names
if name not in self.joint_names and name != "universe"
]

reduced_robot = self.entire_robot.buildReducedRobot(
list_of_joints_to_lock=fixed_joint_names
)
return reduced_robot

def reorder_array(
self, input_array: List[float], reordering_array: List[int]
) -> List[float]:
Expand Down Expand Up @@ -393,25 +372,7 @@ def _get_fk(
Returns:
torch.Tensor: The homogeneous transformation matrix (4x4) of the end-effector (after applying TCP).
"""
if isinstance(qpos, torch.Tensor):
qpos_np = qpos.detach().cpu().numpy()
else:
qpos_np = np.array(qpos)

qpos_np = np.squeeze(qpos_np)
if qpos_np.ndim != 1:
raise ValueError(f"qpos shape must be (nq,), but got {qpos_np.shape}")

self.pin.forwardKinematics(self.robot.model, self.robot.data, qpos_np)

# Retrieve the pose of the specified link
frame_index = self.robot.model.getFrameId(self.end_link_name)
joint_index = self.robot.model.frames[frame_index].parent
xpos_se3 = self.robot.data.oMi.tolist()[joint_index]

xpos = np.eye(4)
xpos[:3, :3] = xpos_se3.rotation
xpos[:3, 3] = xpos_se3.translation.T

result = np.dot(xpos, self.tcp_xpos)
result = compute_pinocchio_fk(
self.pin, self.robot, qpos, self.end_link_name, self.tcp_xpos
)
return torch.from_numpy(result)
80 changes: 13 additions & 67 deletions embodichain/lab/sim/solvers/pinocchio_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
lazy_import_casadi,
# lazy_import_pinocchio_casadi,
)
from embodichain.lab.sim.utility.solver_utils import (
build_reduced_pinocchio_robot,
validate_iteration_params,
compute_pinocchio_fk,
)


if TYPE_CHECKING:
Expand Down Expand Up @@ -72,12 +77,7 @@ def init_solver(self, **kwargs) -> "PinocchioSolver":
solver = PinocchioSolver(cfg=self, **kwargs)

# Set the Tool Center Point (TCP) for the solver
if isinstance(self.tcp, torch.Tensor):
tcp = self.tcp.cpu().numpy()
else:
tcp = self.tcp

solver.set_tcp(tcp)
solver.set_tcp(self._get_tcp_as_numpy())

return solver

Expand Down Expand Up @@ -121,7 +121,7 @@ def __init__(self, cfg: PinocchioSolverCfg, **kwargs):
) # Degrees of freedom of robot joints

# Build reduced robot model (only relevant joints unlocked)
self.robot = self._get_reduce_robot()
self.robot = build_reduced_pinocchio_robot(self.entire_robot, self.joint_names)
self.joint_names = self.robot.model.names.tolist()[
1:
] # Exclude 'universe' joint
Expand Down Expand Up @@ -221,26 +221,6 @@ def __init__(self, cfg: PinocchioSolverCfg, **kwargs):
self.root_base_xpos[:3, :3] = root_base_pose.rotation
self.root_base_xpos[:3, 3] = root_base_pose.translation.T

def _get_reduce_robot(self) -> "pin.RobotWrapper":
"""Build a reduced robot model by locking all joints except those in self.joint_names.

Returns:
pin.RobotWrapper: The reduced robot model with specified joints unlocked.
"""
all_joint_names = self.entire_robot.model.names.tolist()

# Lock all joints except those in self.joint_names and 'universe'
fixed_joint_names = [
name
for name in all_joint_names
if name not in self.joint_names and name != "universe"
]

reduced_robot = self.entire_robot.buildReducedRobot(
list_of_joints_to_lock=fixed_joint_names
)
return reduced_robot

def set_tcp(self, tcp: np.ndarray):
self.tcp = tcp

Expand Down Expand Up @@ -290,25 +270,10 @@ def set_iteration_params(
Returns:
bool: True if all parameters are valid and set, False otherwise.
"""
# TODO: Check which parameters are no longer needed.
# Validate parameters
if pos_eps <= 0:
logger.log_warning("Pos epsilon must be positive.")
return False
if rot_eps <= 0:
logger.log_warning("Rot epsilon must be positive.")
return False
if max_iterations <= 0:
logger.log_warning("Max iterations must be positive.")
return False
if dt <= 0:
logger.log_warning("Time step must be positive.")
return False
if damp < 0:
logger.log_warning("Damping factor must be non-negative.")
return False
if num_samples <= 0:
logger.log_warning("Number of samples must be positive.")
if not validate_iteration_params(
pos_eps, rot_eps, max_iterations, dt, damp, num_samples
):
return False

# Set parameters if all are valid
Expand Down Expand Up @@ -620,25 +585,6 @@ def _get_fk(
Returns:
np.ndarray: The resulting end-effector pose as a (4, 4) homogeneous transformation matrix.
"""
if isinstance(qpos, torch.Tensor):
qpos_np = qpos.detach().cpu().numpy()
else:
qpos_np = np.array(qpos)

qpos_np = np.squeeze(qpos_np)
if qpos_np.ndim != 1:
raise ValueError(f"qpos shape must be (nq,), but got {qpos_np.shape}")

self.pin.forwardKinematics(self.robot.model, self.robot.data, qpos_np)

# Retrieve the pose of the specified link
frame_index = self.robot.model.getFrameId(self.end_link_name)
joint_index = self.robot.model.frames[frame_index].parent
xpos_se3 = self.robot.data.oMi.tolist()[joint_index]

xpos = np.eye(4)
xpos[:3, :3] = xpos_se3.rotation
xpos[:3, 3] = xpos_se3.translation.T

result = np.dot(xpos, self.tcp_xpos)
return result
return compute_pinocchio_fk(
self.pin, self.robot, qpos, self.end_link_name, self.tcp_xpos
)
27 changes: 5 additions & 22 deletions embodichain/lab/sim/solvers/pytorch_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from embodichain.utils import configclass, logger
from embodichain.lab.sim.solvers import SolverCfg, BaseSolver
from embodichain.lab.sim.solvers.qpos_seed_sampler import QposSeedSampler
from embodichain.lab.sim.utility.solver_utils import validate_iteration_params

if TYPE_CHECKING:
from typing import Self
Expand Down Expand Up @@ -90,11 +91,7 @@ def init_solver(
solver = PytorchSolver(cfg=self, device=device, **kwargs)

# Set the Tool Center Point (TCP) for the solver
if isinstance(self.tcp, torch.Tensor):
tcp = self.tcp.cpu().numpy()
else:
tcp = self.tcp
solver.set_tcp(tcp)
solver.set_tcp(self._get_tcp_as_numpy())

return solver

Expand Down Expand Up @@ -227,23 +224,9 @@ def set_iteration_params(
bool: True if all parameters are valid and set, False otherwise.
"""
# Validate parameters
if pos_eps <= 0:
logger.log_warning("Pos epsilon must be positive.")
return False
if rot_eps <= 0:
logger.log_warning("Rot epsilon must be positive.")
return False
if max_iterations <= 0:
logger.log_warning("Max iterations must be positive.")
return False
if dt <= 0:
logger.log_warning("Time step must be positive.")
return False
if damp < 0:
logger.log_warning("Damping factor must be non-negative.")
return False
if num_samples <= 0:
logger.log_warning("Number of samples must be positive.")
if not validate_iteration_params(
pos_eps, rot_eps, max_iterations, dt, damp, num_samples
):
return False

# Set parameters if all are valid
Expand Down
6 changes: 1 addition & 5 deletions embodichain/lab/sim/solvers/srs_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,7 @@ def init_solver(
solver = SRSSolver(cfg=self, num_envs=num_envs, device=device, **kwargs)

# Set the Tool Center Point (TCP) for the solver
if isinstance(self.tcp, torch.Tensor):
tcp = self.tcp.cpu().numpy()
else:
tcp = self.tcp
solver.set_tcp(tcp)
solver.set_tcp(self._get_tcp_as_numpy())

return solver

Expand Down
Loading