Skip to content

Commit 1e45363

Browse files
authored
Merge pull request #18 from KushnirDmytro/optimize_memory_dist_add_symmetric
Memory optimized dists_add_symmetric
2 parents f526e29 + 539672c commit 1e45363

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

cosypose/lib3d/distances.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,12 @@ def dists_add(TXO_pred, TXO_gt, points):
88
dists = TXO_gt_points - TXO_pred_points
99
return dists
1010

11-
1211
def dists_add_symmetric(TXO_pred, TXO_gt, points):
1312
TXO_pred_points = transform_pts(TXO_pred, points)
1413
TXO_gt_points = transform_pts(TXO_gt, points)
15-
dists = TXO_gt_points.unsqueeze(1) - TXO_pred_points.unsqueeze(2)
16-
dists_norm_squared = (dists ** 2).sum(dim=-1)
17-
assign = dists_norm_squared.argmin(dim=1)
18-
ids_row = torch.arange(dists.shape[0]).unsqueeze(1).repeat(1, dists.shape[1])
19-
ids_col = torch.arange(dists.shape[1]).unsqueeze(0).repeat(dists.shape[0], 1)
20-
dists = dists[ids_row, assign, ids_col]
21-
return dists
14+
distances = torch.cdist(TXO_gt_points, TXO_pred_points,
15+
p=2, compute_mode='donot_use_mm_for_euclid_dist')
16+
closest_points_idx = torch.argmin(distances, dim=2).squeeze()
17+
TXO_pred_closest_to_gt = torch.index_select(TXO_pred_points, 1, closest_points_idx)
18+
min_translations = TXO_gt_points - TXO_pred_closest_to_gt
19+
return min_translations

0 commit comments

Comments
 (0)