|
11 | 11 |
|
12 | 12 | from ..box_regression import Box2BoxTransformRotated |
13 | 13 | from .build import PROPOSAL_GENERATOR_REGISTRY |
| 14 | +from .proposal_utils import _is_tracing |
14 | 15 | from .rpn import RPN |
15 | 16 |
|
16 | 17 | logger = logging.getLogger(__name__) |
@@ -67,7 +68,10 @@ def find_top_rrpn_proposals( |
67 | 68 | itertools.count(), proposals, pred_objectness_logits |
68 | 69 | ): |
69 | 70 | Hi_Wi_A = logits_i.shape[1] |
70 | | - num_proposals_i = min(pre_nms_topk, Hi_Wi_A) |
| 71 | + if isinstance(Hi_Wi_A, torch.Tensor): # it's a tensor in tracing |
| 72 | + num_proposals_i = torch.clamp(Hi_Wi_A, max=pre_nms_topk) |
| 73 | + else: |
| 74 | + num_proposals_i = min(Hi_Wi_A, pre_nms_topk) |
71 | 75 |
|
72 | 76 | # sort is faster than topk (https://github.com/pytorch/pytorch/issues/22812) |
73 | 77 | # topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1) |
@@ -101,7 +105,7 @@ def find_top_rrpn_proposals( |
101 | 105 | # filter empty boxes |
102 | 106 | keep = boxes.nonempty(threshold=min_box_size) |
103 | 107 | lvl = level_ids |
104 | | - if keep.sum().item() != len(boxes): |
| 108 | + if _is_tracing() or keep.sum().item() != len(boxes): |
105 | 109 | boxes, scores_per_img, lvl = (boxes[keep], scores_per_img[keep], level_ids[keep]) |
106 | 110 |
|
107 | 111 | keep = batched_nms_rotated(boxes.tensor, scores_per_img, lvl, nms_thresh) |
|
0 commit comments