
Commit 419a0f2

Fix bug that auto batch size doesn't consider distributed training (#2533)
* consider distributed training while searching batch size
* update unit test
* revert gpu memory upper bound
* fix typo
* change allocated to reserved
* add unit test for distributed training
* align with pre-commit
1 parent b0eac19 commit 419a0f2

File tree: 2 files changed, +151 -10 lines


src/otx/algorithms/common/adapters/torch/utils/bs_search_algo.py

Lines changed: 32 additions & 7 deletions

@@ -6,6 +6,7 @@
 from typing import Callable, Dict, Tuple

 import torch
+import torch.distributed as dist

 from otx.algorithms.common.utils.logger import get_logger

@@ -40,7 +41,7 @@ def __init__(self, train_func: Callable[[int], None], default_bs: int, max_bs: i

     def _try_batch_size(self, bs: int) -> Tuple[bool, int]:
         cuda_oom = False
-        torch.cuda.reset_max_memory_allocated(device=None)
+        torch.cuda.reset_max_memory_cached(device=None)
         torch.cuda.empty_cache()

         try:
@@ -51,18 +52,42 @@ def _try_batch_size(self, bs: int) -> Tuple[bool, int]:
             else:
                 raise e

-        max_memory_allocated = torch.cuda.max_memory_allocated(device=None)
+        max_memory_reserved = torch.cuda.max_memory_reserved(device=None)
+
+        if dist.is_initialized():  # Aggregate all results and broadcast to all processes
+            rank = dist.get_rank()
+            try_result = torch.tensor([int(cuda_oom), max_memory_reserved], dtype=torch.int64).cuda()
+
+            if rank == 0:
+                try_result_arr = [torch.empty(2, dtype=torch.int64).cuda() for _ in range(dist.get_world_size())]
+                dist.gather(try_result, gather_list=try_result_arr, dst=0)
+            else:
+                dist.gather(try_result, dst=0)
+
+            if rank == 0:
+                try_result_arr = torch.stack(try_result_arr)
+                cuda_oom = torch.any(try_result_arr[:, 0])  # type: ignore
+                max_memory_reserved = torch.max(try_result_arr[:, 1])  # type: ignore
+                total_try_result = torch.tensor([cuda_oom, max_memory_reserved], dtype=torch.int64).cuda()
+            else:
+                total_try_result = torch.empty(2, dtype=torch.int64).cuda()
+
+            dist.broadcast(total_try_result, src=0)
+
+            cuda_oom = total_try_result[0].bool().item()
+            max_memory_reserved = total_try_result[1].item()
+
         if not cuda_oom:
             # Because heapq only supports min heap, use negatized batch size
-            self._bs_try_history[bs] = max_memory_allocated
+            self._bs_try_history[bs] = max_memory_reserved

         logger.debug(
             f"Adapting Batch size => bs : {bs}, CUDA_OOM : {cuda_oom}, "
-            f"GPU memory usage : {max_memory_allocated / self._total_mem}%"
+            f"GPU memory usage : {max_memory_reserved / self._total_mem}%"
         )
         torch.cuda.empty_cache()

-        return cuda_oom, max_memory_allocated
+        return cuda_oom, max_memory_reserved

     @staticmethod
     def _get_even_center_val(val1: int, val2: int) -> int:
@@ -82,10 +107,10 @@ def auto_decrease_batch_size(self) -> int:
         lowest_unavailable_bs = self._default_bs + 2

         while True:
-            cuda_oom, max_memory_allocated = self._try_batch_size(current_bs)
+            cuda_oom, max_memory_reserved = self._try_batch_size(current_bs)

             # If GPU memory usage is too close to limit, CUDA OOM can be raised during training
-            if cuda_oom or max_memory_allocated > self._mem_upper_bound:
+            if cuda_oom or max_memory_reserved > self._mem_upper_bound:
                 if current_bs < lowest_unavailable_bs:
                     lowest_unavailable_bs = current_bs
                 current_bs = self._get_even_center_val(current_bs, available_bs)
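
The heart of the change is a gather-then-broadcast step: each rank reports its own (cuda_oom, max_memory_reserved) pair, rank 0 reduces them (any-OOM for the flag, max for the memory figure), and the reduced pair is broadcast back so every rank takes the same branch in the search. Below is a minimal standalone sketch of that pattern, for illustration only; it assumes a process group is already initialized (e.g. via torchrun), and the helper name aggregate_try_result is hypothetical, not part of OTX.

import torch
import torch.distributed as dist


def aggregate_try_result(cuda_oom: bool, max_memory_reserved: int):
    """Reduce one rank's result across all ranks and hand the same answer back to everyone."""
    rank = dist.get_rank()
    # Pack the local result into a tensor so it can travel through collectives.
    local = torch.tensor([int(cuda_oom), max_memory_reserved], dtype=torch.int64).cuda()

    if rank == 0:
        # Rank 0 collects every rank's [cuda_oom, max_memory_reserved] pair...
        gathered = [torch.empty(2, dtype=torch.int64).cuda() for _ in range(dist.get_world_size())]
        dist.gather(local, gather_list=gathered, dst=0)
        stacked = torch.stack(gathered)
        # ...treats OOM on any rank as OOM, and keeps the worst-case memory usage.
        reduced = torch.tensor(
            [int(torch.any(stacked[:, 0])), int(torch.max(stacked[:, 1]))], dtype=torch.int64
        ).cuda()
    else:
        dist.gather(local, dst=0)
        reduced = torch.empty(2, dtype=torch.int64).cuda()

    # Every rank receives the same reduced pair, so the batch size search stays in sync.
    dist.broadcast(reduced, src=0)
    return bool(reduced[0].item()), int(reduced[1].item())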

tests/unit/algorithms/common/adapters/torch/utils/test_bs_search_algo.py

Lines changed: 119 additions & 3 deletions

@@ -1,4 +1,7 @@
+from typing import Optional, List
+
 import pytest
+import torch

 from tests.test_suite.e2e_test_system import e2e_pytest_unit
 from otx.algorithms.common.adapters.torch.utils import BsSearchAlgo
@@ -11,6 +14,8 @@ class TestBsSearchAlgo:
     def setup_test(self, mocker):
         self.mock_torch = mocker.patch.object(bs_search_algo, "torch")
         self.mock_torch.cuda.mem_get_info.return_value = (1, 10000)
+        self.mock_dist = mocker.patch.object(bs_search_algo, "dist")
+        self.mock_dist.is_initialized.return_value = False

     def test_init(self, mocker):
         BsSearchAlgo(mocker.MagicMock(), 4, 10)
@@ -35,11 +40,122 @@ def mock_train_func(batch_size):
             else:
                 mem_usage = 8500 * batch_size / max_runnable_bs

-            self.mock_torch.cuda.max_memory_allocated.return_value = mem_usage
+            self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage
             return mem_usage

         return mock_train_func

+    def test_try_batch_size(self):
+        mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80)
+        bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
+        batch_size = 40
+
+        cuda_oom, max_memory_reserved = bs_search_algo._try_batch_size(batch_size)
+
+        assert cuda_oom is False
+        assert max_memory_reserved == mock_train_func(batch_size)
+        self.mock_torch.cuda.reset_max_memory_cached.assert_called()
+        self.mock_torch.cuda.empty_cache.assert_called()
+
+    def test_try_batch_size_cuda_oom(self):
+        mock_train_func = self.get_mock_train_func(cuda_oom_bound=100, max_runnable_bs=80)
+        bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
+        batch_size = 200
+
+        cuda_oom, _ = bs_search_algo._try_batch_size(batch_size)
+
+        assert cuda_oom is True
+        self.mock_torch.cuda.reset_max_memory_cached.assert_called()
+        self.mock_torch.cuda.empty_cache.assert_called()
+
+    def _prepare_dist_test(self, broadcast_val: torch.Tensor, gather_val: Optional[List[torch.Tensor]] = None):
+        self.mock_dist.is_initialized.return_value = True
+
+        # mocking torch.distributed.broadcast
+        def mock_broadcast(tensor: torch.Tensor, src: int):
+            tensor.copy_(broadcast_val)
+
+        self.mock_dist.broadcast.side_effect = mock_broadcast
+
+        # mocking torch.distributed.gather if gather_val is given
+        def mock_gather(tensor: torch.Tensor, gather_list: Optional[List[torch.Tensor]] = None, dst: int = 0):
+            for i in range(len(gather_list)):
+                gather_list[i].copy_(gather_val[i])
+
+        if gather_val is not None:
+            self.mock_dist.gather.side_effect = mock_gather
+
+        # revert some of the torch functions
+        def mock_tensor_cuda(self, *args, **kwargs):
+            return self
+
+        torch.Tensor.cuda = mock_tensor_cuda
+        self.mock_torch.tensor = torch.tensor
+        self.mock_torch.int64 = torch.int64
+        self.mock_torch.max = torch.max
+        self.mock_torch.any = torch.any
+        self.mock_torch.stack = torch.stack
+        self.mock_torch.empty = torch.empty
+
+    def test_try_batch_size_distributed_not_rank_0(self):
+        self.mock_dist.get_rank.return_value = 1
+        broadcasted_cuda_oom = False
+        broadcasted_max_memory_reserved = 4000
+        self._prepare_dist_test(
+            broadcast_val=torch.tensor([broadcasted_cuda_oom, broadcasted_max_memory_reserved], dtype=torch.int64)
+        )
+        mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80)
+        batch_size = 40
+        bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
+        w1_max_memory_reserved = mock_train_func(batch_size)
+
+        cuda_oom, max_memory_reserved = bs_search_algo._try_batch_size(batch_size)
+
+        # check dist.gather is called and gets [cuda_oom, max_memory_reserved] as arguments.
+        self.mock_dist.gather.assert_called_once()
+        assert self.mock_dist.gather.call_args.args[0][0].item() == False
+        assert self.mock_dist.gather.call_args.args[0][1].item() == w1_max_memory_reserved
+        assert self.mock_dist.gather.call_args.kwargs["dst"] == 0
+        # check dist.broadcast is called
+        self.mock_dist.broadcast.assert_called_once()
+        assert self.mock_dist.broadcast.call_args.kwargs["src"] == 0
+        # check broadcasted values are returned
+        assert cuda_oom is broadcasted_cuda_oom
+        assert max_memory_reserved == broadcasted_max_memory_reserved
+
+    def test_try_batch_size_distributed_rank_0(self):
+        self.mock_dist.get_rank.return_value = 0
+        self.mock_dist.get_world_size.return_value = 2
+        self._prepare_dist_test(
+            broadcast_val=torch.tensor([True, 4000], dtype=torch.int64),
+            gather_val=[
+                torch.tensor([False, 3000], dtype=torch.int64),
+                torch.tensor([True, 4000], dtype=torch.int64),
+            ],
+        )
+        mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80)
+        batch_size = 40
+        bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
+        w0_max_memory_reserved = mock_train_func(batch_size)
+
+        cuda_oom, max_memory_reserved = bs_search_algo._try_batch_size(batch_size)
+
+        # check dist.gather is called and gets [cuda_oom, max_memory_reserved] as arguments.
+        self.mock_dist.gather.assert_called_once()
+        assert self.mock_dist.gather.call_args.args[0][0].item() == False
+        assert self.mock_dist.gather.call_args.args[0][1].item() == w0_max_memory_reserved
+        assert self.mock_dist.gather.call_args.kwargs["dst"] == 0
+        # check that if any process hits CUDA OOM, cuda_oom is set to True and
+        # max_memory_reserved is set to the maximum value across processes
+        self.mock_dist.broadcast.assert_called_once()
+        assert self.mock_dist.broadcast.call_args.kwargs["src"] == 0
+        assert self.mock_dist.broadcast.call_args.args[0][0].item() == True
+        assert self.mock_dist.broadcast.call_args.args[0][1].item() == 4000
+        # check proper values are returned
+        assert cuda_oom is True
+        assert max_memory_reserved == 4000
+
     def test_auto_decrease_batch_size(self):
         mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80)

@@ -91,7 +207,7 @@ def mock_train_func(batch_size):
                 mem_usage = 9000
             else:
                 mem_usage = 1000
-            self.mock_torch.cuda.max_memory_allocated.return_value = mem_usage
+            self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage
             return mem_usage

         bs_search_algo = BsSearchAlgo(mock_train_func, 64, 1000)
@@ -108,7 +224,7 @@ def mock_train_func(batch_size):
                 mem_usage = 9000
             else:
                 mem_usage = 1000 + batch_size / 1000
-            self.mock_torch.cuda.max_memory_allocated.return_value = mem_usage
+            self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage
             return mem_usage

         bs_search_algo = BsSearchAlgo(mock_train_func, 64, 1000)
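
The tests above stub out torch.distributed entirely, so no process group is created. To exercise the new branch for real, the search has to run under a multi-process launcher such as torchrun, one process per GPU. A hedged usage sketch follows; the BsSearchAlgo(train_func, default_bs, max_bs) call mirrors the constructor used in these tests, while the script name and the train_once helper are hypothetical.

# Hypothetical launch: torchrun --nproc_per_node=<num_gpus> search_bs.py
import os

import torch
import torch.distributed as dist

from otx.algorithms.common.adapters.torch.utils import BsSearchAlgo


def train_once(batch_size: int) -> None:
    # Hypothetical stand-in for one short training run at the given batch size.
    ...


if __name__ == "__main__":
    # torchrun provides RANK/WORLD_SIZE/LOCAL_RANK; NCCL is the usual backend for GPUs.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    algo = BsSearchAlgo(train_once, default_bs=64, max_bs=1000)
    # With the fix, every rank sees the same OOM / memory verdict per tried batch size,
    # so all ranks converge on the same final batch size.
    batch_size = algo.auto_decrease_batch_size()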

0 commit comments
