|
7 | 7 |
|
8 | 8 | from abc import abstractmethod |
9 | 9 | from collections import defaultdict |
| 10 | +from typing import Callable |
10 | 11 |
|
11 | 12 | import cv2 |
12 | 13 | import numpy as np |
13 | 14 | import torch |
| 15 | +from packaging import version |
14 | 16 | from torchvision import tv_tensors |
15 | 17 | from torchvision.ops import batched_nms |
16 | 18 |
|
|
21 | 23 | from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntity, InstanceSegPredEntity |
22 | 24 | from otx.core.data.entity.segmentation import SegBatchPredEntity, SegPredEntity |
23 | 25 |
|
# Largest value a signed 32-bit integer can hold (2**31 - 1).
# Serves as the element-count ceiling for a single tensor operation in this module.
MAX_ELEMENTS: int = int(np.iinfo(np.int32).max)
| 28 | + |
| 29 | + |
def int_max_check_condition(tile_masks: torch.Tensor) -> bool:
    """Tell whether ``tile_masks`` is too large for the non-chunked code path.

    NOTE: RuntimeError: nonzero is not supported for tensors with more than INT_MAX elements,
    See https://github.com/pytorch/pytorch/issues/51871

    Args:
        tile_masks (torch.Tensor): Tensor whose total element count is checked.

    Returns:
        bool: True only when the installed torch version is older than 2.6 and the
            tensor holds more than MAX_ELEMENTS elements.
    """
    # The size limit is only applied to torch releases before 2.6; check the version first
    # so the element count is never inspected on newer releases.
    if version.parse(torch.__version__) >= version.parse("2.6"):
        return False
    return torch.numel(tile_masks) > MAX_ELEMENTS
| 36 | + |
| 37 | + |
def keep_chunkify(tensor: torch.Tensor, max_element: int = MAX_ELEMENTS) -> torch.Tensor:
    """Find, chunk by chunk, which entries of a batch have a positive sum.

    The batch is split along dim 0 so that each chunk holds at most ``max_element``
    elements (but always at least one entry), and every chunk is reduced separately.

    Args:
        tensor (torch.Tensor): Input tensor of shape (B, H, W).
        max_element (int): Upper bound on the number of elements processed per chunk.
            Defaults to MAX_ELEMENTS.

    Returns:
        torch.Tensor: Boolean mask of shape (B,), True where the (H, W) slice sums to
            a value greater than zero.
    """
    batch_size, h, w = tensor.shape
    elements_per_item = h * w

    if batch_size == 0 or elements_per_item == 0:
        # An empty batch would leave nothing to concatenate (torch.cat rejects an empty
        # list), and zero-area slices would divide by zero below. In both cases no slice
        # can have a positive sum, so every entry is False.
        return torch.zeros(batch_size, dtype=torch.bool, device=tensor.device)

    max_batch_size = int(max_element) // elements_per_item
    # Always take at least one entry per chunk, even if a single slice exceeds the budget.
    chunk_size = max(1, min(max_batch_size, batch_size))

    keep_indices = [
        tensor[start : start + chunk_size].sum(dim=(1, 2)) > 0
        for start in range(0, batch_size, chunk_size)
    ]
    return torch.cat(keep_indices, dim=0)
| 57 | + |
24 | 58 |
|
25 | 59 | class TileMerge: |
26 | 60 | """Base class for tile merge. |
@@ -332,7 +366,10 @@ def merge( |
332 | 366 | feature_vectors, |
333 | 367 | strict=True, |
334 | 368 | ): |
335 | | - keep_indices = tile_masks.to_sparse().sum((1, 2)).to_dense() > 0 |
| 369 | + if int_max_check_condition(tile_masks): |
| 370 | + keep_indices = keep_chunkify(tile_masks) |
| 371 | + else: |
| 372 | + keep_indices = tile_masks.to_sparse().sum((1, 2)).to_dense() > 0 |
336 | 373 | keep_indices = keep_indices.nonzero(as_tuple=True)[0] |
337 | 374 | _bboxes = tile_bboxes[keep_indices] |
338 | 375 | _labels = tile_labels[keep_indices] |
|
0 commit comments