@@ -32,6 +32,8 @@ def _scale_batch_size(
32
32
init_val : int = 2 ,
33
33
max_trials : int = 25 ,
34
34
batch_arg_name : str = "batch_size" ,
35
+ margin : float = 0.05 ,
36
+ max_val : Optional [int ] = None ,
35
37
) -> Optional [int ]:
36
38
"""Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
37
39
error.
@@ -58,6 +60,10 @@ def _scale_batch_size(
58
60
- ``model.hparams``
59
61
- ``trainer.datamodule`` (the datamodule passed to the tune method)
60
62
63
+ margin: Margin to reduce the found batch size by to provide a safety buffer. Only applied when using
64
+ 'binsearch' mode. Should be a float between 0 and 1. Defaults to 0.05 (5% reduction).
65
+ max_val: Maximum batch size limit. If provided, the found batch size will not exceed this value.
66
+
61
67
"""
62
68
if trainer .fast_dev_run :
63
69
rank_zero_warn ("Skipping batch size scaler since `fast_dev_run` is enabled." )
@@ -79,9 +85,9 @@ def _scale_batch_size(
79
85
new_size , _ = _adjust_batch_size (trainer , batch_arg_name , value = init_val )
80
86
81
87
if mode == "power" :
82
- new_size = _run_power_scaling (trainer , new_size , batch_arg_name , max_trials , params )
88
+ new_size = _run_power_scaling (trainer , new_size , batch_arg_name , max_trials , params , max_val )
83
89
elif mode == "binsearch" :
84
- new_size = _run_binary_scaling (trainer , new_size , batch_arg_name , max_trials , params )
90
+ new_size = _run_binsearch_scaling (trainer , new_size , batch_arg_name , max_trials , params , margin , max_val )
85
91
86
92
garbage_collection_cuda ()
87
93
@@ -170,6 +176,7 @@ def _run_power_scaling(
170
176
batch_arg_name : str ,
171
177
max_trials : int ,
172
178
params : dict [str , Any ],
179
+ max_val : Optional [int ],
173
180
) -> int :
174
181
"""Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
175
182
# this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
@@ -183,7 +190,9 @@ def _run_power_scaling(
183
190
184
191
try :
185
192
_try_loop_run (trainer , params )
186
- new_size , changed = _adjust_batch_size (trainer , batch_arg_name , factor = 2.0 , desc = "succeeded" )
193
+ new_size , changed = _adjust_batch_size (
194
+ trainer , batch_arg_name , factor = 2.0 , desc = "succeeded" , max_val = max_val
195
+ )
187
196
188
197
if not changed :
189
198
break
@@ -206,12 +215,14 @@ def _run_power_scaling(
206
215
return new_size
207
216
208
217
209
- def _run_binary_scaling (
218
+ def _run_binsearch_scaling (
210
219
trainer : "pl.Trainer" ,
211
220
new_size : int ,
212
221
batch_arg_name : str ,
213
222
max_trials : int ,
214
223
params : dict [str , Any ],
224
+ margin : float ,
225
+ max_val : Optional [int ],
215
226
) -> int :
216
227
"""Batch scaling mode where the size is initially doubled at each iteration until an OOM error is encountered.
217
228
@@ -239,9 +250,13 @@ def _run_binary_scaling(
239
250
if high - low <= 1 :
240
251
break
241
252
midval = (high + low ) // 2
242
- new_size , changed = _adjust_batch_size (trainer , batch_arg_name , value = midval , desc = "succeeded" )
253
+ new_size , changed = _adjust_batch_size (
254
+ trainer , batch_arg_name , value = midval , desc = "succeeded" , max_val = max_val
255
+ )
243
256
else :
244
- new_size , changed = _adjust_batch_size (trainer , batch_arg_name , factor = 2.0 , desc = "succeeded" )
257
+ new_size , changed = _adjust_batch_size (
258
+ trainer , batch_arg_name , factor = 2.0 , desc = "succeeded" , max_val = max_val
259
+ )
245
260
246
261
if not changed :
247
262
break
@@ -267,6 +282,15 @@ def _run_binary_scaling(
267
282
else :
268
283
raise # some other error not memory related
269
284
285
+ # Apply margin reduction for binsearch mode
286
+ if margin > 0 :
287
+ margin_reduced_size = max (1 , int (new_size * (1 - margin )))
288
+ if margin_reduced_size != new_size :
289
+ rank_zero_info (
290
+ f"Applying margin of { margin :.1%} , reducing batch size from { new_size } to { margin_reduced_size } "
291
+ )
292
+ new_size = margin_reduced_size
293
+
270
294
return new_size
271
295
272
296
@@ -276,6 +300,7 @@ def _adjust_batch_size(
276
300
factor : float = 1.0 ,
277
301
value : Optional [int ] = None ,
278
302
desc : Optional [str ] = None ,
303
+ max_val : Optional [int ] = None ,
279
304
) -> tuple [int , bool ]:
280
305
"""Helper function for adjusting the batch size.
281
306
@@ -286,6 +311,7 @@ def _adjust_batch_size(
286
311
value: if a value is given, will override the batch size with this value.
287
312
Note that the value of `factor` will not have an effect in this case
288
313
desc: either ``"succeeded"`` or ``"failed"``. Used purely for logging
314
+ max_val: Maximum batch size limit. If provided, the new batch size will not exceed this value.
289
315
290
316
Returns:
291
317
The new batch size for the next trial and a bool that signals whether the
@@ -311,6 +337,12 @@ def _adjust_batch_size(
311
337
pass
312
338
313
339
new_size = value if value is not None else int (batch_size * factor )
340
+
341
+ # Apply max_val limit if provided
342
+ if max_val is not None and new_size > max_val :
343
+ if desc :
344
+ rank_zero_info (f"Batch size { new_size } exceeds max_val limit { max_val } , capping at { max_val } " )
345
+ new_size = max_val
314
346
if desc :
315
347
rank_zero_info (f"Batch size { batch_size } { desc } , trying batch size { new_size } " )
316
348
changed = new_size != batch_size
0 commit comments