Skip to content

Commit cef0d56

Browse files
committed
add to callback
1 parent ad0fbd7 commit cef0d56

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

src/lightning/pytorch/callbacks/batch_size_finder.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ class BatchSizeFinder(Callback):
6363
- ``model.hparams``
6464
- ``trainer.datamodule`` (the datamodule passed to the tune method)
6565
66+
margin: Margin by which to reduce the found batch size, providing a safety buffer. Only applied when using
67+
'binsearch' mode. Should be a float between 0 and 1. Defaults to 0.05 (5% reduction).
68+
max_val: Maximum batch size limit. If provided, the found batch size will not exceed this value.
69+
6670
Example::
6771
6872
# 1. Customize the BatchSizeFinder callback to run at different epochs. This feature is
@@ -118,6 +122,8 @@ def __init__(
118122
init_val: int = 2,
119123
max_trials: int = 25,
120124
batch_arg_name: str = "batch_size",
125+
margin: float = 0.05,
126+
max_val: Optional[int] = None,
121127
) -> None:
122128
mode = mode.lower()
123129
if mode not in self.SUPPORTED_MODES:
@@ -129,6 +135,8 @@ def __init__(
129135
self._init_val = init_val
130136
self._max_trials = max_trials
131137
self._batch_arg_name = batch_arg_name
138+
self._margin = margin
139+
self._max_val = max_val
132140
self._early_exit = False
133141

134142
@override
@@ -180,6 +188,8 @@ def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule
180188
self._init_val,
181189
self._max_trials,
182190
self._batch_arg_name,
191+
self._margin,
192+
self._max_val,
183193
)
184194

185195
self.optimal_batch_size = new_size

0 commit comments

Comments (0)