Skip to content

Commit 3800e83

Browse files
Copilotyuecideng
andauthored
Fix performance issues: remove duplicates, fix broken deprecation warnings, optimize loops (#19)
Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: yuecideng <[email protected]>
1 parent ea28c20 commit 3800e83

File tree

4 files changed

+46
-92
lines changed

4 files changed

+46
-92
lines changed

embodichain/lab/sim/solvers/differential_solver.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# ----------------------------------------------------------------------------
1616

1717
import torch
18-
from copy import deepcopy
1918
from typing import Optional, Union, Tuple, Any, Literal, TYPE_CHECKING
2019
from scipy.spatial.transform import Rotation
2120

@@ -245,11 +244,13 @@ def get_ik(
245244
current_xpos = self.get_fk(qpos_seed, to_matrix=True)
246245

247246
# Transform target_xpos by TCP
247+
# Note: torch.as_tensor does not modify the input, so deepcopy is unnecessary
248248
tcp_xpos = torch.as_tensor(
249-
deepcopy(self.tcp_xpos), device=self.device, dtype=torch.float32
249+
self.tcp_xpos, device=self.device, dtype=torch.float32
250250
)
251-
current_xpos = current_xpos @ torch.inverse(tcp_xpos)
252-
compute_xpos = target_xpos @ torch.inverse(tcp_xpos)
251+
tcp_xpos_inv = torch.inverse(tcp_xpos)
252+
current_xpos = current_xpos @ tcp_xpos_inv
253+
compute_xpos = target_xpos @ tcp_xpos_inv
253254

254255
# Ensure compute_xpos is a batch of matrices
255256
if current_xpos.dim() == 2 and current_xpos.shape == (4, 4):

embodichain/toolkits/graspkit/pg_grasp/antipodal.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -187,12 +187,11 @@ def get_dis_arr(self, others) -> np.ndarray:
187187
Returns:
188188
np.ndarray: distance array
189189
"""
190-
other_num = len(others)
191-
other_a = np.empty(shape=(other_num, 3), dtype=float)
192-
other_b = np.empty(shape=(other_num, 3), dtype=float)
193-
for i in range(other_num):
194-
other_a[i] = others[i].point_a
195-
other_b[i] = others[i].point_b
190+
if not others:
191+
return np.array([], dtype=float)
192+
# Vectorized extraction of points using list comprehension and np.array
193+
other_a = np.array([o.point_a for o in others], dtype=float)
194+
other_b = np.array([o.point_b for o in others], dtype=float)
196195
aa_dis = np.linalg.norm(other_a - self.point_a, axis=1)
197196
ab_dis = np.linalg.norm(other_a - self.point_b, axis=1)
198197
ba_dis = np.linalg.norm(other_b - self.point_a, axis=1)
@@ -382,22 +381,26 @@ def _generate_cache(
382381
# self.antipodal_visual(nms_antipodal_list)
383382
grasp_num = grasp_poses.shape[0]
384383
logger.log_debug(f"Write {grasp_num} poses to pickle file {cache_file}.")
385-
grasp_list = [None for i in range(grasp_num)]
386-
for i in range(grasp_num):
387-
grasp_list[i] = AntipodalGrasp(grasp_poses[i], open_length[i], score[i])
384+
# Use list comprehension for efficient list construction
385+
grasp_list = [
386+
AntipodalGrasp(grasp_poses[i], open_length[i], score[i])
387+
for i in range(grasp_num)
388+
]
388389
return grasp_list
389390

390391
def _load_cache(self, cache_file: str):
391392
data_dict = pickle.load(open(cache_file, "rb"))
392393
grasp_num = data_dict["grasp_poses"].shape[0]
393394
logger.log_debug(f"Load {grasp_num} poses from pickle file {cache_file}.")
394-
grasp_list = [None for i in range(grasp_num)]
395-
for i in range(grasp_num):
396-
grasp_list[i] = AntipodalGrasp(
395+
# Use list comprehension for efficient list construction
396+
grasp_list = [
397+
AntipodalGrasp(
397398
data_dict["grasp_poses"][i],
398399
data_dict["open_length"][i],
399400
data_dict["score"][i],
400401
)
402+
for i in range(grasp_num)
403+
]
401404
return grasp_list
402405

403406
def _get_pc_size(self, vertices):
@@ -521,13 +524,11 @@ def select_grasp(
521524
"""
522525
grasp_num = len(self._grasp_list)
523526
all_idx = np.arange(grasp_num)
524-
grasp_poses = np.empty(shape=(grasp_num, 4, 4), dtype=float)
525-
scores = np.empty(shape=(grasp_num,), dtype=float)
526-
position = grasp_poses[:, :3, 3]
527527

528-
for i in range(grasp_num):
529-
grasp_poses[i] = self._grasp_list[i].pose
530-
scores[i] = self._grasp_list[i].score
528+
# Vectorized extraction of poses and scores using list comprehension
529+
grasp_poses = np.array([g.pose for g in self._grasp_list], dtype=float)
530+
scores = np.array([g.score for g in self._grasp_list], dtype=float)
531+
position = grasp_poses[:, :3, 3]
531532

532533
# mask acoording to table up direction
533534
grasp_z = grasp_poses[:, :3, 2]
@@ -557,10 +558,8 @@ def select_grasp(
557558
best_valid_idx = sort_valid_idx[:result_num]
558559
best_idx = valid_id[best_valid_idx]
559560

560-
result_grasp_list = []
561-
for idx in best_idx:
562-
result_grasp_list.append(self._grasp_list[idx])
563-
return result_grasp_list
561+
# Use list comprehension for faster list construction
562+
return [self._grasp_list[idx] for idx in best_idx]
564563

565564
def _antipodal_nms(
566565
self, antipodal_list: List[Antipodal], nms_ratio: float = 0.02

embodichain/utils/math.py

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import annotations
1919

2020
import math
21+
import warnings
2122
import torch
2223
import numpy as np
2324
import torch.nn.functional
@@ -394,10 +395,8 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
394395
Reference:
395396
https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L91-L99
396397
"""
397-
ret = torch.zeros_like(x)
398-
positive_mask = x > 0
399-
ret[positive_mask] = torch.sqrt(x[positive_mask])
400-
return ret
398+
# Use torch.where for vectorized operation instead of indexed assignment
399+
return torch.where(x > 0, torch.sqrt(x), torch.zeros_like(x))
401400

402401

403402
@torch.jit.script
@@ -508,38 +507,6 @@ def trans_matrix_to_xyz_quat(matrix: torch.Tensor) -> torch.Tensor:
508507
return vec
509508

510509

511-
@torch.jit.script
512-
def quat_from_euler_xyz(
513-
roll: torch.Tensor, pitch: torch.Tensor, yaw: torch.Tensor
514-
) -> torch.Tensor:
515-
"""Convert rotations given as Euler angles in radians to Quaternions.
516-
517-
Note:
518-
The euler angles are assumed in XYZ convention.
519-
520-
Args:
521-
roll: Rotation around x-axis (in radians). Shape is (N,).
522-
pitch: Rotation around y-axis (in radians). Shape is (N,).
523-
yaw: Rotation around z-axis (in radians). Shape is (N,).
524-
525-
Returns:
526-
The quaternion in (w, x, y, z). Shape is (N, 4).
527-
"""
528-
cy = torch.cos(yaw * 0.5)
529-
sy = torch.sin(yaw * 0.5)
530-
cr = torch.cos(roll * 0.5)
531-
sr = torch.sin(roll * 0.5)
532-
cp = torch.cos(pitch * 0.5)
533-
sp = torch.sin(pitch * 0.5)
534-
# compute quaternion
535-
qw = cy * cr * cp + sy * sr * sp
536-
qx = cy * sr * cp - sy * cr * sp
537-
qy = cy * cr * sp + sy * sr * cp
538-
qz = sy * cr * cp - cy * sr * sp
539-
540-
return torch.stack([qw, qx, qy, qz], dim=-1)
541-
542-
543510
def _axis_angle_rotation(
544511
axis: Literal["X", "Y", "Z"], angle: torch.Tensor
545512
) -> torch.Tensor:

embodichain/utils/utility.py

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -352,13 +352,6 @@ def save_json(path: str, data):
352352
json.dump(data, f, indent=4)
353353

354354

355-
def save_json(path: str, data):
356-
import json
357-
358-
with open(path, "w") as f:
359-
json.dump(data, f, indent=4)
360-
361-
362355
def load_json(path: str) -> Dict:
363356
import json
364357

@@ -455,14 +448,22 @@ def postprocess_small_regions(
455448
min_area: int,
456449
max_area: int,
457450
) -> List[int]:
458-
keep_idx = []
451+
"""Filter masks based on area constraints.
452+
453+
Args:
454+
masks: Array of binary masks or list of masks.
455+
min_area: Minimum area threshold (exclusive - areas must be strictly greater).
456+
max_area: Maximum area threshold (inclusive - areas can equal this value).
457+
458+
Returns:
459+
List of indices for masks that meet the area constraints (min_area < area <= max_area).
460+
"""
459461
n = len(masks) if isinstance(masks, list) else masks.shape[0]
460-
for i in range(n):
461-
area = masks[i].astype(np.uint8).sum()
462-
keep = area > min_area and area <= max_area
463-
if keep:
464-
keep_idx.append(i)
465-
return keep_idx
462+
# Use list comprehension for more efficient filtering
463+
# Logic: area > min_area and area <= max_area (original behavior preserved)
464+
return [
465+
i for i in range(n) if min_area < masks[i].astype(np.uint8).sum() <= max_area
466+
]
466467

467468

468469
def mask_to_box(mask: np.ndarray) -> np.ndarray:
@@ -478,27 +479,13 @@ def mask_to_box(mask: np.ndarray) -> np.ndarray:
478479
return bbox
479480

480481

481-
def postprocess_small_regions(
482-
masks: np.ndarray,
483-
min_area: int,
484-
max_area: int,
485-
) -> List[int]:
486-
keep_idx = []
487-
n = len(masks) if isinstance(masks, list) else masks.shape[0]
488-
for i in range(n):
489-
area = masks[i].astype(np.uint8).sum()
490-
keep = area > min_area and area <= max_area
491-
if keep:
492-
keep_idx.append(i)
493-
return keep_idx
494-
495-
496482
def remove_overlap_mask(
497483
masks: List[np.ndarray], keep_inner_threshold: float = 0.5, eps: float = 1e-5
498484
) -> List[int]:
499485
keep_ids = []
500486

501-
areas = [mask.astype(np.uint8).sum() for mask in masks]
487+
# Pre-compute areas once for efficiency
488+
areas = np.array([mask.astype(np.uint8).sum() for mask in masks])
502489

503490
for i, maskA in enumerate(masks):
504491
keep = True

0 commit comments

Comments
 (0)