Skip to content

Commit 873de18

Browse files
committed
update
1 parent f2fa81e commit 873de18

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

src/lightning/pytorch/callbacks/batch_size_finder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class BatchSizeFinder(Callback):
6565
6666
margin: Margin to reduce the found batch size by to provide a safety buffer. Only applied when using
6767
'binsearch' mode. Should be a float between 0 and 1. Defaults to 0.05 (5% reduction).
68-
max_val: Maximum batch size limit, to prevent overly large or inefficient batch sizes.
68+
max_val: Maximum batch size limit, to prevent overly large or inefficient batch sizes. Defaults to 8192.
6969
7070
Example::
7171
@@ -123,7 +123,7 @@ def __init__(
123123
max_trials: int = 25,
124124
batch_arg_name: str = "batch_size",
125125
margin: float = 0.05,
126-
max_val: int = 1024,
126+
max_val: int = 8192,
127127
) -> None:
128128
mode = mode.lower()
129129
if mode not in self.SUPPORTED_MODES:

src/lightning/pytorch/tuner/batch_size_scaling.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def _scale_batch_size(
3333
max_trials: int = 25,
3434
batch_arg_name: str = "batch_size",
3535
margin: float = 0.05,
36-
max_val: int = 1024,
36+
max_val: int = 8192,
3737
) -> Optional[int]:
3838
"""Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
3939
error.
@@ -179,7 +179,7 @@ def _run_power_scaling(
179179
batch_arg_name: str,
180180
max_trials: int,
181181
params: dict[str, Any],
182-
max_val: Optional[int],
182+
max_val: int = 8192,
183183
) -> int:
184184
"""Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
185185
# this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
@@ -192,7 +192,7 @@ def _run_power_scaling(
192192
# reset after each try
193193
_reset_progress(trainer)
194194

195-
if max_val is not None and new_size >= max_val:
195+
if new_size >= max_val:
196196
rank_zero_info(f"Reached the maximum batch size limit of {max_val}. Stopping search.")
197197
break
198198

@@ -235,7 +235,7 @@ def _run_binsearch_scaling(
235235
max_trials: int,
236236
params: dict[str, Any],
237237
margin: float,
238-
max_val: Optional[int],
238+
max_val: int = 8192,
239239
) -> int:
240240
"""Batch scaling mode where the size is initially doubled at each iteration until an OOM error is encountered.
241241
@@ -252,7 +252,7 @@ def _run_binsearch_scaling(
252252
# reset after each try
253253
_reset_progress(trainer)
254254

255-
if max_val is not None and new_size >= max_val:
255+
if new_size >= max_val:
256256
rank_zero_info(f"Reached the maximum batch size limit of {max_val}. Stopping search.")
257257
break
258258

@@ -325,7 +325,7 @@ def _adjust_batch_size(
325325
factor: float = 1.0,
326326
value: Optional[int] = None,
327327
desc: Optional[str] = None,
328-
max_val: Optional[int] = None,
328+
max_val: int = 8192,
329329
) -> tuple[int, bool]:
330330
"""Helper function for adjusting the batch size.
331331
@@ -336,7 +336,7 @@ def _adjust_batch_size(
336336
value: if a value is given, will override the batch size with this value.
337337
Note that the value of `factor` will not have an effect in this case
338338
desc: either ``"succeeded"`` or ``"failed"``. Used purely for logging
339-
max_val: Maximum batch size limit. If provided, the new batch size will not exceed this value.
339+
max_val: Maximum batch size limit, to prevent overly large or inefficient batch sizes.
340340
341341
Returns:
342342
The new batch size for the next trial and a bool that signals whether the
@@ -367,7 +367,7 @@ def _adjust_batch_size(
367367
new_size = value if value is not None else int(batch_size * factor)
368368

369369
# Cap the new batch size at max_val
370-
if max_val is not None and new_size > max_val:
370+
if new_size > max_val:
371371
if desc:
372372
rank_zero_info(f"Batch size {new_size} exceeds max_val limit {max_val}, capping at {max_val}")
373373
new_size = max_val

0 commit comments

Comments
 (0)