Merged
9 changes: 5 additions & 4 deletions embodichain/lab/sim/solvers/differential_solver.py
@@ -15,7 +15,6 @@
# ----------------------------------------------------------------------------

import torch
-from copy import deepcopy
from typing import Optional, Union, Tuple, Any, Literal, TYPE_CHECKING
from scipy.spatial.transform import Rotation

@@ -245,11 +244,13 @@ def get_ik(
current_xpos = self.get_fk(qpos_seed, to_matrix=True)

# Transform target_xpos by TCP
+        # Note: torch.as_tensor does not modify the input, so deepcopy is unnecessary
tcp_xpos = torch.as_tensor(
-            deepcopy(self.tcp_xpos), device=self.device, dtype=torch.float32
+            self.tcp_xpos, device=self.device, dtype=torch.float32
)
-        current_xpos = current_xpos @ torch.inverse(tcp_xpos)
-        compute_xpos = target_xpos @ torch.inverse(tcp_xpos)
+        tcp_xpos_inv = torch.inverse(tcp_xpos)
+        current_xpos = current_xpos @ tcp_xpos_inv
+        compute_xpos = target_xpos @ tcp_xpos_inv

# Ensure compute_xpos is a batch of matrices
if current_xpos.dim() == 2 and current_xpos.shape == (4, 4):
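A note on the two edits in this hunk, with a minimal standalone sketch (the identity matrix below is a stand-in for the solver's tcp_xpos, not its real value): torch.as_tensor returns the input tensor itself when the requested dtype and device already match, so the deepcopy only paid for an extra copy. Because the result may share storage with the input, this is safe as long as the tensor is only read, which holds here since matrix multiplication does not modify its operands. Factoring out torch.inverse(tcp_xpos) likewise avoids inverting the same 4x4 matrix twice.

import torch

tcp_xpos = torch.eye(4, dtype=torch.float32)  # illustrative stand-in for self.tcp_xpos

alias = torch.as_tensor(tcp_xpos, dtype=torch.float32)
print(alias is tcp_xpos)  # True: dtype and device match, so no copy is made

tcp_xpos_inv = torch.inverse(tcp_xpos)  # inverted once, reused below
current_xpos = torch.eye(4) @ tcp_xpos_inv
compute_xpos = torch.eye(4) @ tcp_xpos_inv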
43 changes: 21 additions & 22 deletions embodichain/toolkits/graspkit/pg_grasp/antipodal.py
@@ -187,12 +187,11 @@ def get_dis_arr(self, others) -> np.ndarray:
Returns:
np.ndarray: distance array
"""
-        other_num = len(others)
-        other_a = np.empty(shape=(other_num, 3), dtype=float)
-        other_b = np.empty(shape=(other_num, 3), dtype=float)
-        for i in range(other_num):
-            other_a[i] = others[i].point_a
-            other_b[i] = others[i].point_b
+        if not others:
+            return np.array([], dtype=float)
+        # Vectorized extraction of points using list comprehension and np.array
+        other_a = np.array([o.point_a for o in others], dtype=float)
+        other_b = np.array([o.point_b for o in others], dtype=float)
aa_dis = np.linalg.norm(other_a - self.point_a, axis=1)
ab_dis = np.linalg.norm(other_a - self.point_b, axis=1)
ba_dis = np.linalg.norm(other_b - self.point_a, axis=1)
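The rewrite above replaces the preallocate-and-fill loop with stacked (N, 3) arrays and guards the empty case before any array math. A minimal sketch of the same broadcasting pattern, with made-up points standing in for the Antipodal contact points:

import numpy as np

other_a = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=float)  # from others[i].point_a
point_a = np.array([0.0, 1.0, 0.0])                                  # this grasp's point_a

# Broadcasting subtracts point_a from every row; norming along axis=1 yields
# all distances in one call instead of a Python loop.
aa_dis = np.linalg.norm(other_a - point_a, axis=1)  # shape (2,)

others = []
if not others:                        # early return for an empty candidate list
    dis_arr = np.array([], dtype=float)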
@@ -382,22 +381,26 @@ def _generate_cache(
# self.antipodal_visual(nms_antipodal_list)
grasp_num = grasp_poses.shape[0]
logger.log_debug(f"Write {grasp_num} poses to pickle file {cache_file}.")
-        grasp_list = [None for i in range(grasp_num)]
-        for i in range(grasp_num):
-            grasp_list[i] = AntipodalGrasp(grasp_poses[i], open_length[i], score[i])
+        # Use list comprehension for efficient list construction
+        grasp_list = [
+            AntipodalGrasp(grasp_poses[i], open_length[i], score[i])
+            for i in range(grasp_num)
+        ]
return grasp_list

def _load_cache(self, cache_file: str):
data_dict = pickle.load(open(cache_file, "rb"))
grasp_num = data_dict["grasp_poses"].shape[0]
logger.log_debug(f"Load {grasp_num} poses from pickle file {cache_file}.")
-        grasp_list = [None for i in range(grasp_num)]
-        for i in range(grasp_num):
-            grasp_list[i] = AntipodalGrasp(
+        # Use list comprehension for efficient list construction
+        grasp_list = [
+            AntipodalGrasp(
data_dict["grasp_poses"][i],
data_dict["open_length"][i],
data_dict["score"][i],
)
+            for i in range(grasp_num)
+        ]
return grasp_list

def _get_pc_size(self, vertices):
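For context, a rough sketch of the cache round trip that _generate_cache and _load_cache implement, assuming only the dict keys visible in the diff; the file path is made up and a plain tuple stands in for AntipodalGrasp:

import pickle
import numpy as np

data_dict = {
    "grasp_poses": np.tile(np.eye(4), (2, 1, 1)),   # (N, 4, 4)
    "open_length": np.array([0.04, 0.05]),
    "score": np.array([0.9, 0.7]),
}
cache_file = "grasp_cache.pkl"  # illustrative path
with open(cache_file, "wb") as f:
    pickle.dump(data_dict, f)

with open(cache_file, "rb") as f:
    loaded = pickle.load(f)
# Same comprehension shape as _load_cache, with a tuple standing in for AntipodalGrasp.
grasp_list = [
    (loaded["grasp_poses"][i], loaded["open_length"][i], loaded["score"][i])
    for i in range(loaded["grasp_poses"].shape[0])
]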
@@ -521,13 +524,11 @@ def select_grasp(
"""
grasp_num = len(self._grasp_list)
all_idx = np.arange(grasp_num)
-        grasp_poses = np.empty(shape=(grasp_num, 4, 4), dtype=float)
-        scores = np.empty(shape=(grasp_num,), dtype=float)
-        position = grasp_poses[:, :3, 3]

-        for i in range(grasp_num):
-            grasp_poses[i] = self._grasp_list[i].pose
-            scores[i] = self._grasp_list[i].score
+        # Vectorized extraction of poses and scores using list comprehension
+        grasp_poses = np.array([g.pose for g in self._grasp_list], dtype=float)
+        scores = np.array([g.score for g in self._grasp_list], dtype=float)
+        position = grasp_poses[:, :3, 3]

# mask acoording to table up direction
grasp_z = grasp_poses[:, :3, 2]
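With the poses stacked into an (N, 4, 4) array as above, the per-grasp geometry falls out of simple slices. A small sketch; the downward-pointing test at the end is an assumption, since the actual comparison against the table up direction is outside this hunk:

import numpy as np

grasp_poses = np.tile(np.eye(4), (3, 1, 1))  # stand-in for the stacked grasp poses
position = grasp_poses[:, :3, 3]             # (N, 3) translation of each grasp
grasp_z = grasp_poses[:, :3, 2]              # (N, 3) z-axis (approach direction)

table_up = np.array([0.0, 0.0, 1.0])
mask = grasp_z @ table_up < 0.0              # assumed: keep grasps approaching from above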
@@ -557,10 +558,8 @@ def select_grasp(
best_valid_idx = sort_valid_idx[:result_num]
best_idx = valid_id[best_valid_idx]

-        result_grasp_list = []
-        for idx in best_idx:
-            result_grasp_list.append(self._grasp_list[idx])
-        return result_grasp_list
+        # Use list comprehension for faster list construction
+        return [self._grasp_list[idx] for idx in best_idx]

def _antipodal_nms(
self, antipodal_list: List[Antipodal], nms_ratio: float = 0.02
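The loop-and-append above collapses to a single comprehension. For context, the surrounding best_idx selection can be sketched with an argsort over the valid scores; the diff does not show how sort_valid_idx is produced, so the descending sort below is an assumption and the values are made up:

import numpy as np

scores = np.array([0.3, 0.8, 0.1, 0.6])
valid_id = np.array([0, 1, 3])          # indices that survived the masks
result_num = 2

sort_valid_idx = np.argsort(scores[valid_id])[::-1]   # assumed: best score first
best_valid_idx = sort_valid_idx[:result_num]
best_idx = valid_id[best_valid_idx]                   # array([1, 3])

grasp_list = ["g0", "g1", "g2", "g3"]                 # stand-in for self._grasp_list
result_grasp_list = [grasp_list[idx] for idx in best_idx]  # ["g1", "g3"]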
39 changes: 3 additions & 36 deletions embodichain/utils/math.py
@@ -18,6 +18,7 @@
from __future__ import annotations

import math
+import warnings
import torch
import numpy as np
import torch.nn.functional
@@ -394,10 +395,8 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
Reference:
https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L91-L99
"""
-    ret = torch.zeros_like(x)
-    positive_mask = x > 0
-    ret[positive_mask] = torch.sqrt(x[positive_mask])
-    return ret
+    # Use torch.where for vectorized operation instead of indexed assignment
+    return torch.where(x > 0, torch.sqrt(x), torch.zeros_like(x))


@torch.jit.script
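The forward values of the two forms are identical. One caveat worth noting: the torch.where version evaluates torch.sqrt on the non-positive entries too, and torch.where is known to propagate NaN gradients from the unselected branch, so the indexed-assignment form remains the safer choice if gradients ever flow through this helper. A quick check of the value equivalence:

import torch

x = torch.tensor([-1.0, 0.0, 4.0])

ret = torch.zeros_like(x)
positive_mask = x > 0
ret[positive_mask] = torch.sqrt(x[positive_mask])               # original form

fused = torch.where(x > 0, torch.sqrt(x), torch.zeros_like(x))  # new form
print(torch.equal(ret, fused))  # True: both give tensor([0., 0., 2.])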
@@ -508,38 +507,6 @@ def trans_matrix_to_xyz_quat(matrix: torch.Tensor) -> torch.Tensor:
return vec


-@torch.jit.script
-def quat_from_euler_xyz(
-    roll: torch.Tensor, pitch: torch.Tensor, yaw: torch.Tensor
-) -> torch.Tensor:
-    """Convert rotations given as Euler angles in radians to Quaternions.
-
-    Note:
-        The euler angles are assumed in XYZ convention.
-
-    Args:
-        roll: Rotation around x-axis (in radians). Shape is (N,).
-        pitch: Rotation around y-axis (in radians). Shape is (N,).
-        yaw: Rotation around z-axis (in radians). Shape is (N,).
-
-    Returns:
-        The quaternion in (w, x, y, z). Shape is (N, 4).
-    """
-    cy = torch.cos(yaw * 0.5)
-    sy = torch.sin(yaw * 0.5)
-    cr = torch.cos(roll * 0.5)
-    sr = torch.sin(roll * 0.5)
-    cp = torch.cos(pitch * 0.5)
-    sp = torch.sin(pitch * 0.5)
-    # compute quaternion
-    qw = cy * cr * cp + sy * sr * sp
-    qx = cy * sr * cp - sy * cr * sp
-    qy = cy * cr * sp + sy * sr * cp
-    qz = sy * cr * cp - cy * sr * sp
-
-    return torch.stack([qw, qx, qy, qz], dim=-1)
-
-
def _axis_angle_rotation(
axis: Literal["X", "Y", "Z"], angle: torch.Tensor
) -> torch.Tensor:
47 changes: 17 additions & 30 deletions embodichain/utils/utility.py
@@ -352,13 +352,6 @@ def save_json(path: str, data):
json.dump(data, f, indent=4)


-def save_json(path: str, data):
-    import json
-
-    with open(path, "w") as f:
-        json.dump(data, f, indent=4)
-
-
def load_json(path: str) -> Dict:
import json

@@ -455,14 +448,22 @@ def postprocess_small_regions(
min_area: int,
max_area: int,
) -> List[int]:
-    keep_idx = []
+    """Filter masks based on area constraints.
+
+    Args:
+        masks: Array of binary masks or list of masks.
+        min_area: Minimum area threshold (exclusive - areas must be strictly greater).
+        max_area: Maximum area threshold (inclusive - areas can equal this value).
+
+    Returns:
+        List of indices for masks that meet the area constraints (min_area < area <= max_area).
+    """
n = len(masks) if isinstance(masks, list) else masks.shape[0]
-    for i in range(n):
-        area = masks[i].astype(np.uint8).sum()
-        keep = area > min_area and area <= max_area
-        if keep:
-            keep_idx.append(i)
-    return keep_idx
+    # Use list comprehension for more efficient filtering
+    # Logic: area > min_area and area <= max_area (original behavior preserved)
+    return [
+        i for i in range(n) if min_area < masks[i].astype(np.uint8).sum() <= max_area
+    ]


def mask_to_box(mask: np.ndarray) -> np.ndarray:
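A small check of the boundary semantics the new docstring spells out (min_area exclusive, max_area inclusive), using the same chained comparison as the rewritten return; the masks are toy examples:

import numpy as np

masks = [np.ones((4, 4), dtype=bool), np.ones((2, 2), dtype=bool)]  # areas 16 and 4
min_area, max_area = 4, 16

keep_idx = [
    i for i in range(len(masks))
    if min_area < masks[i].astype(np.uint8).sum() <= max_area
]
print(keep_idx)  # [0]: area 16 passes (16 <= max_area), area 4 fails (not > min_area)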
@@ -478,27 +479,13 @@ def mask_to_box(mask: np.ndarray) -> np.ndarray:
return bbox


-def postprocess_small_regions(
-    masks: np.ndarray,
-    min_area: int,
-    max_area: int,
-) -> List[int]:
-    keep_idx = []
-    n = len(masks) if isinstance(masks, list) else masks.shape[0]
-    for i in range(n):
-        area = masks[i].astype(np.uint8).sum()
-        keep = area > min_area and area <= max_area
-        if keep:
-            keep_idx.append(i)
-    return keep_idx
-
-
def remove_overlap_mask(
masks: List[np.ndarray], keep_inner_threshold: float = 0.5, eps: float = 1e-5
) -> List[int]:
keep_ids = []

-    areas = [mask.astype(np.uint8).sum() for mask in masks]
+    # Pre-compute areas once for efficiency
+    areas = np.array([mask.astype(np.uint8).sum() for mask in masks])

for i, maskA in enumerate(masks):
keep = True
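The hunk is truncated here, so only the hoisted area computation is visible. The sketch below fills in the rest purely by assumption — overlap scored as intersection over the smaller mask's area, dropping the smaller mask past keep_inner_threshold — to illustrate why computing the areas once outside the nested loop pays off; the real criterion lives in the elided lines.

import numpy as np
from typing import List

def remove_overlap_mask_sketch(
    masks: List[np.ndarray], keep_inner_threshold: float = 0.5, eps: float = 1e-5
) -> List[int]:
    keep_ids = []
    # Areas are computed once up front instead of inside the nested loop.
    areas = np.array([mask.astype(np.uint8).sum() for mask in masks])
    for i, mask_a in enumerate(masks):
        keep = True
        for j, mask_b in enumerate(masks):
            if i == j:
                continue
            inter = np.logical_and(mask_a, mask_b).sum()
            # Assumed rule: drop the smaller mask when a larger one mostly covers it.
            if areas[i] <= areas[j] and inter / (areas[i] + eps) > keep_inner_threshold:
                keep = False
                break
        if keep:
            keep_ids.append(i)
    return keep_ids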