1+ from typing import Optional , List
2+
13import pytest
4+ import torch
25
36from tests .test_suite .e2e_test_system import e2e_pytest_unit
47from otx .algorithms .common .adapters .torch .utils import BsSearchAlgo
@@ -11,6 +14,8 @@ class TestBsSearchAlgo:
def setup_test(self, mocker):
    """Swap the ``torch`` and ``dist`` names of ``bs_search_algo`` for mocks.

    The mocks are kept on the test instance as ``mock_torch`` and ``mock_dist``.
    ``torch.cuda.mem_get_info`` is preset to return ``(1, 10000)`` and
    ``dist.is_initialized`` is preset to return ``False``.
    """
    torch_mock = mocker.patch.object(bs_search_algo, "torch")
    torch_mock.cuda.mem_get_info.return_value = (1, 10000)
    self.mock_torch = torch_mock

    dist_mock = mocker.patch.object(bs_search_algo, "dist")
    dist_mock.is_initialized.return_value = False
    self.mock_dist = dist_mock
1419
1520 def test_init (self , mocker ):
1621 BsSearchAlgo (mocker .MagicMock (), 4 , 10 )
@@ -35,11 +40,122 @@ def mock_train_func(batch_size):
3540 else :
3641 mem_usage = 8500 * batch_size / max_runnable_bs
3742
38- self .mock_torch .cuda .max_memory_allocated .return_value = mem_usage
43+ self .mock_torch .cuda .max_memory_reserved .return_value = mem_usage
3944 return mem_usage
4045
4146 return mock_train_func
4247
def test_try_batch_size(self):
    """Check ``_try_batch_size`` for batch size 40 with a mock built with ``max_runnable_bs=80``.

    Expects no CUDA OOM, a reserved-memory value equal to what the mocked training
    function returns for that batch size, and that the CUDA cache helpers were called.
    """
    train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80)
    algo = BsSearchAlgo(train_func, 128, 1000)
    trial_bs = 40

    cuda_oom, reserved = algo._try_batch_size(trial_bs)

    assert cuda_oom is False
    assert reserved == train_func(trial_bs)
    cuda_mock = self.mock_torch.cuda
    cuda_mock.reset_max_memory_cached.assert_called()
    cuda_mock.empty_cache.assert_called()
59+
def test_try_batch_size_cuda_oom(self):
    """Check ``_try_batch_size`` for batch size 200 with a mock built with ``cuda_oom_bound=100``.

    Expects the CUDA OOM flag to be ``True`` and the CUDA cache helpers to have been called.
    """
    train_func = self.get_mock_train_func(cuda_oom_bound=100, max_runnable_bs=80)
    algo = BsSearchAlgo(train_func, 128, 1000)
    trial_bs = 200

    cuda_oom, _ = algo._try_batch_size(trial_bs)

    assert cuda_oom is True
    cuda_mock = self.mock_torch.cuda
    cuda_mock.reset_max_memory_cached.assert_called()
    cuda_mock.empty_cache.assert_called()
70+
def _prepare_dist_test(self, broadcast_val: torch.Tensor, gather_val: Optional[List[torch.Tensor]] = None):
    """Configure the mocked ``dist`` and ``torch`` for a distributed scenario.

    Args:
        broadcast_val: Tensor copied in place into the tensor handed to the mocked
            ``dist.broadcast``.
        gather_val: Tensors copied, index by index, into the ``gather_list`` handed to
            the mocked ``dist.gather``. When ``None``, no side effect is installed on
            ``dist.gather``.
    """
    self.mock_dist.is_initialized.return_value = True

    # mocking torch.distributed.broadcast
    def mock_broadcast(tensor: torch.Tensor, src: int):
        tensor.copy_(broadcast_val)

    self.mock_dist.broadcast.side_effect = mock_broadcast

    # mocking torch.distributed.gather if gather_val is given
    if gather_val is not None:

        def mock_gather(tensor: torch.Tensor, gather_list: Optional[List[torch.Tensor]] = None, dst: int = 0):
            # The signature allows gather_list to be omitted; there is nothing to fill then.
            if gather_list is None:
                return
            for idx, target in enumerate(gather_list):
                target.copy_(gather_val[idx])

        self.mock_dist.gather.side_effect = mock_gather

    # make Tensor.cuda() a no-op that returns the tensor itself
    def mock_tensor_cuda(self, *args, **kwargs):
        return self

    # NOTE(review): this replaces ``cuda`` on the real ``torch.Tensor`` class and nothing in
    # this helper restores it, so the patch stays in place after the calling test finishes.
    torch.Tensor.cuda = mock_tensor_cuda

    # expose the real torch functions on the mocked torch module
    for name in ("tensor", "int64", "max", "any", "stack", "empty"):
        setattr(self.mock_torch, name, getattr(torch, name))
99+
def test_try_batch_size_distributed_not_rank_0(self):
    """Check ``_try_batch_size`` when ``dist.get_rank`` reports rank 1.

    The process is expected to pass its own ``[cuda_oom, max_memory_reserved]`` to
    ``dist.gather`` with ``dst=0``, call ``dist.broadcast`` with ``src=0``, and return
    the values that the mocked broadcast wrote.
    """
    self.mock_dist.get_rank.return_value = 1
    broadcasted_cuda_oom = False
    broadcasted_max_memory_reserved = 4000
    self._prepare_dist_test(
        broadcast_val=torch.tensor([broadcasted_cuda_oom, broadcasted_max_memory_reserved], dtype=torch.int64)
    )
    mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80)
    batch_size = 40
    algo = BsSearchAlgo(mock_train_func, 128, 1000)
    w1_max_memory_reserved = mock_train_func(batch_size)

    cuda_oom, max_memory_reserved = algo._try_batch_size(batch_size)

    # check dist.gather is called once and gets [cuda_oom, max_memory_reserved] as arguments
    self.mock_dist.gather.assert_called_once()
    gather_call = self.mock_dist.gather.call_args
    # `== 0` matches an integer 0 as well as a boolean False, without a literal-bool comparison
    assert gather_call.args[0][0].item() == 0
    assert gather_call.args[0][1].item() == w1_max_memory_reserved
    assert gather_call.kwargs["dst"] == 0
    # check dist.broadcast is called once with rank 0 as the source
    self.mock_dist.broadcast.assert_called_once()
    assert self.mock_dist.broadcast.call_args.kwargs["src"] == 0
    # check the broadcast values are returned
    assert cuda_oom is broadcasted_cuda_oom
    assert max_memory_reserved == broadcasted_max_memory_reserved
125+
def test_try_batch_size_distributed_rank_0(self):
    """Check ``_try_batch_size`` when ``dist.get_rank`` reports rank 0 in a world of size 2.

    The mocked ``dist.gather`` fills the gather list with ``[0, 3000]`` and ``[1, 4000]``.
    The test expects ``dist.broadcast`` to be called once with ``src=0`` and the call to
    return ``cuda_oom=True`` with ``max_memory_reserved == 4000``.
    """
    self.mock_dist.get_rank.return_value = 0
    self.mock_dist.get_world_size.return_value = 2
    self._prepare_dist_test(
        broadcast_val=torch.tensor([True, 4000], dtype=torch.int64),
        gather_val=[
            torch.tensor([False, 3000], dtype=torch.int64),
            torch.tensor([True, 4000], dtype=torch.int64),
        ],
    )
    mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80)
    batch_size = 40
    algo = BsSearchAlgo(mock_train_func, 128, 1000)
    w0_max_memory_reserved = mock_train_func(batch_size)

    cuda_oom, max_memory_reserved = algo._try_batch_size(batch_size)

    # check dist.gather is called once and gets [cuda_oom, max_memory_reserved] as arguments
    self.mock_dist.gather.assert_called_once()
    gather_call = self.mock_dist.gather.call_args
    # `== 0` matches an integer 0 as well as a boolean False, without a literal-bool comparison
    assert gather_call.args[0][0].item() == 0
    assert gather_call.args[0][1].item() == w0_max_memory_reserved
    assert gather_call.kwargs["dst"] == 0
    # check that if any process got CUDA OOM, cuda_oom is set to True and
    # max_memory_reserved is set to the maximum value across processes
    self.mock_dist.broadcast.assert_called_once()
    broadcast_call = self.mock_dist.broadcast.call_args
    assert broadcast_call.kwargs["src"] == 0
    # NOTE(review): the mocked dist.broadcast installed by _prepare_dist_test copies
    # broadcast_val ([1, 4000] here) in place into its tensor argument, and the mock keeps a
    # reference to that same tensor, so the two value checks below read the post-copy
    # contents. Confirm they still exercise the rank-0 reduction as intended.
    assert broadcast_call.args[0][0].item() == 1
    assert broadcast_call.args[0][1].item() == 4000
    # check proper values are returned
    assert cuda_oom is True
    assert max_memory_reserved == 4000
158+
43159 def test_auto_decrease_batch_size (self ):
44160 mock_train_func = self .get_mock_train_func (cuda_oom_bound = 10000 , max_runnable_bs = 80 )
45161
@@ -91,7 +207,7 @@ def mock_train_func(batch_size):
91207 mem_usage = 9000
92208 else :
93209 mem_usage = 1000
94- self .mock_torch .cuda .max_memory_allocated .return_value = mem_usage
210+ self .mock_torch .cuda .max_memory_reserved .return_value = mem_usage
95211 return mem_usage
96212
97213 bs_search_algo = BsSearchAlgo (mock_train_func , 64 , 1000 )
@@ -108,7 +224,7 @@ def mock_train_func(batch_size):
108224 mem_usage = 9000
109225 else :
110226 mem_usage = 1000 + batch_size / 1000
111- self .mock_torch .cuda .max_memory_allocated .return_value = mem_usage
227+ self .mock_torch .cuda .max_memory_reserved .return_value = mem_usage
112228 return mem_usage
113229
114230 bs_search_algo = BsSearchAlgo (mock_train_func , 64 , 1000 )
0 commit comments