Skip to content

Commit 638e8a8

Browse files
committed
safe default
1 parent 873de18 commit 638e8a8

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

src/lightning/pytorch/callbacks/batch_size_finder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ class BatchSizeFinder(Callback):
         margin: Margin to reduce the found batch size by to provide a safety buffer. Only applied when using
             'binsearch' mode. Should be a float between 0 and 1. Defaults to 0.05 (5% reduction).
-        max_val: Maximum batch size limit, to prevent overly large or inefficient batch sizes. Defaults to 8192.
+        max_val: Maximum batch size limit, defaults to 8192.
+            Helps prevent testing unrealistically large or inefficient batch sizes (e.g., 2**25)
+            when running on CPU or when automatic OOM detection is not available.

     Example::

src/lightning/pytorch/tuner/batch_size_scaling.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def _scale_batch_size(
         margin: Margin to reduce the found batch size by to provide a safety buffer. Only applied when using
             'binsearch' mode. Should be a float between 0 and 1. Defaults to 0.05 (5% reduction).
-        max_val: Maximum batch size limit, to prevent overly large or inefficient batch sizes.
+        max_val: Maximum batch size limit, defaults to 8192.
+            Helps prevent testing unrealistically large or inefficient batch sizes (e.g., 2**25)
+            when running on CPU or when automatic OOM detection is not available.

     """
     if trainer.fast_dev_run:
@@ -336,7 +338,9 @@ def _adjust_batch_size(
         value: if a value is given, will override the batch size with this value.
             Note that the value of `factor` will not have an effect in this case
         desc: either ``"succeeded"`` or ``"failed"``. Used purely for logging
-        max_val: Maximum batch size limit, to prevent overly large or inefficient batch sizes.
+        max_val: Maximum batch size limit, defaults to 8192.
+            Helps prevent testing unrealistically large or inefficient batch sizes (e.g., 2**25)
+            when running on CPU or when automatic OOM detection is not available.

     Returns:
         The new batch size for the next trial and a bool that signals whether the

src/lightning/pytorch/tuner/tuning.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def scale_batch_size(
         max_trials: int = 25,
         batch_arg_name: str = "batch_size",
         margin: float = 0.05,
-        max_val: Optional[int] = None,
+        max_val: int = 8192,
     ) -> Optional[int]:
         """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
         error.
@@ -79,7 +79,9 @@ def scale_batch_size(
             margin: Margin to reduce the found batch size by to provide a safety buffer. Only applied when using
                 'binsearch' mode. Should be a float between 0 and 1. Defaults to 0.05 (5% reduction).
-            max_val: Maximum batch size limit. If provided, the found batch size will not exceed this value.
+            max_val: Maximum batch size limit, defaults to 8192.
+                Helps prevent testing unrealistically large or inefficient batch sizes (e.g., 2**25)
+                when running on CPU or when automatic OOM detection is not available.

         """
         _check_tuner_configuration(train_dataloaders, val_dataloaders, dataloaders, method)

0 commit comments

Comments
 (0)