Skip to content

Commit ad0fbd7

Browse files
committed
add to tuner
1 parent a0664fb commit ad0fbd7

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

src/lightning/pytorch/tuner/tuning.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def scale_batch_size(
4141
init_val: int = 2,
4242
max_trials: int = 25,
4343
batch_arg_name: str = "batch_size",
44+
margin: float = 0.05,
45+
max_val: Optional[int] = None,
4446
) -> Optional[int]:
4547
"""Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
4648
error.
@@ -75,6 +77,10 @@ def scale_batch_size(
7577
- ``model.hparams``
7678
- ``trainer.datamodule`` (the datamodule passed to the tune method)
7779
80+
margin: Margin to reduce the found batch size by to provide a safety buffer. Only applied when using
81+
'binsearch' mode. Should be a float between 0 and 1. Defaults to 0.05 (5% reduction).
82+
max_val: Maximum batch size limit. If provided, the found batch size will not exceed this value.
83+
7884
"""
7985
_check_tuner_configuration(train_dataloaders, val_dataloaders, dataloaders, method)
8086
_check_scale_batch_size_configuration(self._trainer)
@@ -88,6 +94,8 @@ def scale_batch_size(
8894
init_val=init_val,
8995
max_trials=max_trials,
9096
batch_arg_name=batch_arg_name,
97+
margin=margin,
98+
max_val=max_val,
9199
)
92100
# do not continue with the loop in case Tuner is used
93101
batch_size_finder._early_exit = True

0 commit comments

Comments (0)