|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved |
| 2 | + |
| 3 | +from collections import defaultdict |
| 4 | + |
| 5 | +import torch |
| 6 | +import torch.nn.functional as F |
| 7 | +from sam3.perflib.masks_ops import mask_iou |
| 8 | +from scipy.optimize import linear_sum_assignment |
| 9 | + |
| 10 | + |
def associate_det_trk(
    det_masks,
    track_masks,
    iou_threshold=0.5,
    iou_threshold_trk=0.5,
    det_scores=None,
    new_det_thresh=0.0,
):
    """
    Optimized implementation of detection <-> track association that minimizes DtoH syncs.

    Masks are binarized with `> 0`, the pairwise mask IoU is computed on the masks'
    device, and everything needed by the branchy matching logic is then copied to
    the host in bulk so that the Python loops never touch a device tensor.

    Args:
        det_masks: (N, H, W) tensor of predicted masks
        track_masks: (M, H, W) tensor of track masks; if its spatial size differs
            from det_masks, the larger of the two is bilinearly resized to the smaller
        iou_threshold: minimum IoU for a (detection, track) pair to count as a match
            when deciding which detections are new and when building det_to_matched_trk
        iou_threshold_trk: minimum IoU for a Hungarian-assigned pair to mark its
            track as matched
        det_scores: optional tensor holding one score per detection (indexed like
            det_masks along dim 0)
        new_det_thresh: minimum score an unmatched detection needs to be reported as new

    Returns:
        new_det_indices: list of indices in det_masks considered 'new'
        unmatched_trk_indices: list of indices in track_masks considered 'unmatched'
        det_to_matched_trk: dict mapping a detection index to the list of track
            indices whose IoU with it is >= iou_threshold (only detections with at
            least one such track have an entry)
        matched_det_scores: dict mapping a track index to
            [det_score, det_score * iou] of the detection assigned to it by the
            Hungarian matching; empty when det_scores is None
    """
    with torch.autograd.profiler.record_function("perflib: associate_det_trk"):
        assert isinstance(det_masks, torch.Tensor), "det_masks should be a tensor"
        assert isinstance(track_masks, torch.Tensor), "track_masks should be a tensor"
        num_dets = det_masks.size(0)
        num_trks = track_masks.size(0)
        if num_dets == 0 or num_trks == 0:
            # Nothing to associate with: all detections are new
            return list(range(num_dets)), [], defaultdict(list), {}

        det_hw = tuple(det_masks.shape[-2:])
        trk_hw = tuple(track_masks.shape[-2:])
        if det_hw != trk_hw:
            # Resize to the smaller size to save GPU memory. The comparison is on
            # the spatial area (H * W), independent of how many masks there are.
            if det_hw[0] * det_hw[1] < trk_hw[0] * trk_hw[1]:
                track_masks = F.interpolate(
                    track_masks.unsqueeze(1).float(),
                    size=det_hw,
                    mode="bilinear",
                    align_corners=False,
                ).squeeze(1)
            else:
                # resize detections to track size
                det_masks = F.interpolate(
                    det_masks.unsqueeze(1).float(),
                    size=trk_hw,
                    mode="bilinear",
                    align_corners=False,
                ).squeeze(1)

        det_masks = det_masks > 0
        track_masks = track_masks > 0

        iou = mask_iou(det_masks, track_masks)  # (N, M)
        # Threshold on the device so the comparisons use the IoU tensor's dtype
        det_hit = iou >= iou_threshold
        trk_hit = iou >= iou_threshold_trk
        assert num_trks == det_hit.size(1)  # the loops below assume one column per track

        # Bulk device-to-host copies; the loops below only read Python lists
        iou_np = iou.cpu().numpy()
        iou_list = iou_np.tolist()
        det_hit_list = det_hit.cpu().tolist()
        trk_hit_list = trk_hit.cpu().tolist()

        if det_scores is None:
            det_scores_list = None
            score_ok_list = None
        else:
            det_scores_list = det_scores.cpu().float().numpy().tolist()
            # Evaluate the score gate once on the tensor instead of indexing
            # det_scores[d] inside the loop (which would sync per detection)
            score_ok_list = (det_scores >= new_det_thresh).cpu().tolist()

        # Hungarian matching for tracks (one-to-one: each track matches at most one detection).
        # It solves for minimum cost, so use 1 - IoU to maximize IoU.
        row_ind, col_ind = linear_sum_assignment(1 - iou_np)

        matched_trk = set()
        matched_det_scores = {}  # track index -> [det_score, det_score * iou]
        for d, t in zip(row_ind.tolist(), col_ind.tolist()):
            if det_scores_list is not None:
                score = det_scores_list[d]
                matched_det_scores[t] = [score, score * iou_list[d][t]]
            if trk_hit_list[d][t]:
                matched_trk.add(t)

        # Tracks not matched by Hungarian assignment above threshold are unmatched
        unmatched_trk_indices = [t for t in range(num_trks) if t not in matched_trk]

        # For detections: allow many tracks to match to the same detection (many-to-one).
        # So, a detection is 'new' if it does not match any track above threshold.
        new_det_indices = []
        det_to_matched_trk = defaultdict(list)
        for d, hits in enumerate(det_hit_list):
            tracks_hit = [t for t, hit in enumerate(hits) if hit]
            if tracks_hit:
                # for each detection, which tracks it matched to (above threshold)
                det_to_matched_trk[d] = tracks_hit
            elif score_ok_list is not None and score_ok_list[d]:
                # NOTE(review): without det_scores the score gate never passes, so
                # no detection is reported as new on this path -- confirm intended.
                new_det_indices.append(d)

        return (
            new_det_indices,
            unmatched_trk_indices,
            det_to_matched_trk,
            matched_det_scores,
        )
0 commit comments