
Commit a0664fb

new args to batch size scaler

committed · 1 parent cdc0db4 · commit a0664fb

File tree

1 file changed (+38, -6 lines)

src/lightning/pytorch/tuner/batch_size_scaling.py

Lines changed: 38 additions & 6 deletions
@@ -32,6 +32,8 @@ def _scale_batch_size(
     init_val: int = 2,
     max_trials: int = 25,
     batch_arg_name: str = "batch_size",
+    margin: float = 0.05,
+    max_val: Optional[int] = None,
 ) -> Optional[int]:
     """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
     error.
@@ -58,6 +60,10 @@ def _scale_batch_size(
             - ``model.hparams``
             - ``trainer.datamodule`` (the datamodule passed to the tune method)
 
+        margin: Margin to reduce the found batch size by to provide a safety buffer. Only applied when using
+            'binsearch' mode. Should be a float between 0 and 1. Defaults to 0.05 (5% reduction).
+        max_val: Maximum batch size limit. If provided, the found batch size will not exceed this value.
+
     """
     if trainer.fast_dev_run:
         rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
@@ -79,9 +85,9 @@ def _scale_batch_size(
     new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)
 
     if mode == "power":
-        new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+        new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params, max_val)
     elif mode == "binsearch":
-        new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+        new_size = _run_binsearch_scaling(trainer, new_size, batch_arg_name, max_trials, params, margin, max_val)
 
     garbage_collection_cuda()
 
@@ -170,6 +176,7 @@ def _run_power_scaling(
     batch_arg_name: str,
     max_trials: int,
     params: dict[str, Any],
+    max_val: Optional[int],
 ) -> int:
     """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
     # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
@@ -183,7 +190,9 @@ def _run_power_scaling(
 
         try:
             _try_loop_run(trainer, params)
-            new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")
+            new_size, changed = _adjust_batch_size(
+                trainer, batch_arg_name, factor=2.0, desc="succeeded", max_val=max_val
+            )
 
             if not changed:
                 break
@@ -206,12 +215,14 @@ def _run_power_scaling(
     return new_size
 
 
-def _run_binary_scaling(
+def _run_binsearch_scaling(
     trainer: "pl.Trainer",
     new_size: int,
     batch_arg_name: str,
     max_trials: int,
     params: dict[str, Any],
+    margin: float,
+    max_val: Optional[int],
 ) -> int:
     """Batch scaling mode where the size is initially doubled at each iteration until an OOM error is encountered.
 
@@ -239,9 +250,13 @@ def _run_binary_scaling(
                 if high - low <= 1:
                     break
                 midval = (high + low) // 2
-                new_size, changed = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="succeeded")
+                new_size, changed = _adjust_batch_size(
+                    trainer, batch_arg_name, value=midval, desc="succeeded", max_val=max_val
+                )
             else:
-                new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")
+                new_size, changed = _adjust_batch_size(
+                    trainer, batch_arg_name, factor=2.0, desc="succeeded", max_val=max_val
+                )
 
             if not changed:
                 break
@@ -267,6 +282,15 @@ def _run_binary_scaling(
             else:
                 raise  # some other error not memory related
 
+    # Apply margin reduction for binsearch mode
+    if margin > 0:
+        margin_reduced_size = max(1, int(new_size * (1 - margin)))
+        if margin_reduced_size != new_size:
+            rank_zero_info(
+                f"Applying margin of {margin:.1%}, reducing batch size from {new_size} to {margin_reduced_size}"
+            )
+        new_size = margin_reduced_size
+
     return new_size
 
 
@@ -276,6 +300,7 @@ def _adjust_batch_size(
     factor: float = 1.0,
     value: Optional[int] = None,
     desc: Optional[str] = None,
+    max_val: Optional[int] = None,
 ) -> tuple[int, bool]:
     """Helper function for adjusting the batch size.
 
@@ -286,6 +311,7 @@ def _adjust_batch_size(
         value: if a value is given, will override the batch size with this value.
            Note that the value of `factor` will not have an effect in this case
         desc: either ``"succeeded"`` or ``"failed"``. Used purely for logging
+        max_val: Maximum batch size limit. If provided, the new batch size will not exceed this value.
 
     Returns:
         The new batch size for the next trial and a bool that signals whether the
@@ -311,6 +337,12 @@ def _adjust_batch_size(
         pass
 
     new_size = value if value is not None else int(batch_size * factor)
+
+    # Apply max_val limit if provided
+    if max_val is not None and new_size > max_val:
+        if desc:
+            rank_zero_info(f"Batch size {new_size} exceeds max_val limit {max_val}, capping at {max_val}")
+        new_size = max_val
     if desc:
         rank_zero_info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")
     changed = new_size != batch_size
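To make the two new safeguards concrete, here is a minimal, self-contained sketch of the capping and margin arithmetic this commit adds; the helper names apply_max_val and apply_margin are illustrative only and do not appear in the diff:

    from typing import Optional

    def apply_max_val(new_size: int, max_val: Optional[int]) -> int:
        # Mirrors the cap in _adjust_batch_size: never propose a size above max_val.
        if max_val is not None and new_size > max_val:
            return max_val
        return new_size

    def apply_margin(new_size: int, margin: float) -> int:
        # Mirrors the final step of _run_binsearch_scaling: shrink the found
        # size by `margin` as a safety buffer, flooring at 1.
        if margin > 0:
            return max(1, int(new_size * (1 - margin)))
        return new_size

    assert apply_max_val(1024, max_val=512) == 512  # capped at the limit
    assert apply_margin(256, margin=0.05) == 243    # int(256 * 0.95)

With the default margin of 0.05, a binary search that lands on 256 therefore reports 243, trading a little throughput for headroom against OOMs later in training.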

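For end-to-end context, a hedged usage sketch follows. This commit only changes the internal helpers in batch_size_scaling.py; whether margin and max_val are also accepted by the public Tuner.scale_batch_size API is an assumption here, and DemoModel is a hypothetical stand-in:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import lightning.pytorch as pl
    from lightning.pytorch.tuner import Tuner

    class DemoModel(pl.LightningModule):
        # Hypothetical model; the scaler adjusts self.batch_size in place.
        def __init__(self, batch_size: int = 2):
            super().__init__()
            self.batch_size = batch_size
            self.layer = torch.nn.Linear(32, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def train_dataloader(self):
            dataset = TensorDataset(torch.randn(4096, 32), torch.randn(4096, 1))
            return DataLoader(dataset, batch_size=self.batch_size)

    trainer = pl.Trainer(max_epochs=1)
    tuner = Tuner(trainer)
    # Assumes the new arguments are forwarded down to _scale_batch_size
    # (that wiring is not part of this diff):
    new_size = tuner.scale_batch_size(DemoModel(), mode="binsearch", margin=0.05, max_val=512)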