
Commit 464b45a

🔧 Fix auto batch size handling with tiling and improve error logging (#4233)
* 🔧 Fix auto batch size handling with tiling and improve error logging
* Update CHANGELOG to include link for auto batch size fix with tiling
* fix linter

Co-authored-by: Prokofiev Kirill <[email protected]>
1 parent: bccea76

File tree

4 files changed: +50 −7 lines

- CHANGELOG.md
- src/otx/core/utils/tile_merge.py
- src/otx/engine/adaptive_bs/bs_search_algo.py
- tests/unit/engine/adaptive_bs/test_bs_search_algo.py

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -10,6 +10,9 @@ All notable changes to this project will be documented in this file.
 
 ### Bug fixes
 
+- Fix auto batch size with tiling
+  (<https://github.com/openvinotoolkit/training_extensions/pull/4233>)
+
 - Fix exportable code for tiling
   (<https://github.com/openvinotoolkit/training_extensions/pull/4234>)
 - Don't filter empty label from kp arrow
```

src/otx/core/utils/tile_merge.py

Lines changed: 38 additions & 1 deletion
```diff
@@ -7,10 +7,12 @@
 
 from abc import abstractmethod
 from collections import defaultdict
+from typing import Callable
 
 import cv2
 import numpy as np
 import torch
+from packaging import version
 from torchvision import tv_tensors
 from torchvision.ops import batched_nms
 
@@ -21,6 +23,38 @@
 from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntity, InstanceSegPredEntity
 from otx.core.data.entity.segmentation import SegBatchPredEntity, SegPredEntity
 
+# Maximum number of elements: 2**31 - 1
+MAX_ELEMENTS: int = np.iinfo(np.int32).max
+
+
+# NOTE: RuntimeError: nonzero is not supported for tensors with more than INT_MAX elements.
+# See https://github.com/pytorch/pytorch/issues/51871
+int_max_check_condition: Callable[[torch.Tensor], bool] = (
+    lambda tile_masks: version.parse(torch.__version__) < version.parse("2.6")
+    and torch.numel(tile_masks) > MAX_ELEMENTS
+)
+
+
+def keep_chunkify(tensor: torch.Tensor, max_element: int = MAX_ELEMENTS) -> torch.Tensor:
+    """Splits tensor into chunks and processes each chunk separately.
+
+    Args:
+        tensor (torch.Tensor): Input tensor of shape (B, H, W).
+
+    Returns:
+        torch.Tensor: Boolean mask of shape (B,) indicating nonzero sum.
+    """
+    _, h, w = tensor.shape
+    max_batch_size = int(max_element) // (h * w)
+    chunk_size = max(1, min(max_batch_size, tensor.shape[0]))
+
+    keep_indices = []
+    for i in range(0, tensor.shape[0], chunk_size):
+        chunk = tensor[i : i + chunk_size]
+        keep_indices.append(chunk.sum(dim=(1, 2)) > 0)  # per-chunk keep mask
+
+    return torch.cat(keep_indices, dim=0)
+
 
 class TileMerge:
     """Base class for tile merge.
@@ -332,7 +366,10 @@ def merge(
             feature_vectors,
             strict=True,
         ):
-            keep_indices = tile_masks.to_sparse().sum((1, 2)).to_dense() > 0
+            if int_max_check_condition(tile_masks):
+                keep_indices = keep_chunkify(tile_masks)
+            else:
+                keep_indices = tile_masks.to_sparse().sum((1, 2)).to_dense() > 0
             keep_indices = keep_indices.nonzero(as_tuple=True)[0]
             _bboxes = tile_bboxes[keep_indices]
             _labels = tile_labels[keep_indices]
```
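
As a quick sanity check of the new chunked path (illustrative only, not part of the commit): the tiny shapes and the artificially small `max_element` below are made up to force several chunks, and the result is compared against the original sparse-sum computation.

```python
import torch

from otx.core.utils.tile_merge import keep_chunkify  # helper added in this commit

# Tiny (B, H, W) stack of tile masks: tiles 0 and 2 contain foreground pixels.
tile_masks = torch.zeros(3, 4, 4)
tile_masks[0, 1, 1] = 1
tile_masks[2, 0, 3] = 1

# An artificially small max_element forces chunk_size = 1 (three chunks); the
# chunked result should agree with the original sparse-sum path.
chunked = keep_chunkify(tile_masks, max_element=16)
dense = tile_masks.to_sparse().sum((1, 2)).to_dense() > 0
assert torch.equal(chunked, dense)  # tensor([True, False, True])
```

Chunking by `max_element // (h * w)` keeps each chunk under the INT_MAX element limit, so the pre-2.6 `nonzero` restriction is never hit while the per-tile keep decision stays identical.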

src/otx/engine/adaptive_bs/bs_search_algo.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -113,8 +113,13 @@ def auto_decrease_batch_size(self) -> int:
 
         if available_bs == 0:
             if oom:
-                msg = "Current device can't train model even with 2."
-                raise RuntimeError(msg)
+                logger.warning(
+                    "The auto batch size algorithm attempted to use a batch size of 2 but still "
+                    "encountered a CUDA OOM error. OTX will proceed with training at batch size 2; "
+                    "however, you will likely encounter a CUDA OOM error once training starts. "
+                    "If the issue persists, please report it accordingly.",
+                )
+                return 2
             logger.warning(
                 "Even with a batch size of 2, most of the memory is used, "
                 "which could cause the training to fail midway.",
```

tests/unit/engine/adaptive_bs/test_bs_search_algo.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -106,8 +106,7 @@ def test_find_max_usable_bs_gpu_memory_too_small(self):
         mock_train_func = self.get_mock_train_func(cuda_oom_bound=1, max_runnable_bs=1)
 
         bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
-        with pytest.raises(RuntimeError):
-            bs_search_algo.auto_decrease_batch_size()
+        assert bs_search_algo.auto_decrease_batch_size() == 2
 
     def test_auto_decrease_batch_size_bs2_not_oom_but_most_mem(self):
         """Batch size 2 doesn't make oom but use most of memory."""
@@ -140,8 +139,7 @@ def test_find_big_enough_batch_size_gpu_memory_too_small(self):
         mock_train_func = self.get_mock_train_func(cuda_oom_bound=1, max_runnable_bs=1)
 
         bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
-        with pytest.raises(RuntimeError):
-            bs_search_algo.find_big_enough_batch_size()
+        assert bs_search_algo.find_big_enough_batch_size() == 2
 
     def test_find_big_enough_batch_size_bs2_not_oom_but_most_mem(self):
         """Batch size 2 doesn't make oom but use most of memory."""
```
