File tree Expand file tree Collapse file tree 1 file changed +8
-0
lines changed
src/lightning/pytorch/tuner Expand file tree Collapse file tree 1 file changed +8
-0
lines changed Original file line number Diff line number Diff line change @@ -41,6 +41,8 @@ def scale_batch_size(
41
41
init_val : int = 2 ,
42
42
max_trials : int = 25 ,
43
43
batch_arg_name : str = "batch_size" ,
44
+ margin : float = 0.05 ,
45
+ max_val : Optional [int ] = None ,
44
46
) -> Optional [int ]:
45
47
"""Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
46
48
error.
@@ -75,6 +77,10 @@ def scale_batch_size(
75
77
- ``model.hparams``
76
78
- ``trainer.datamodule`` (the datamodule passed to the tune method)
77
79
80
+ margin: Fraction by which to reduce the found batch size, as a safety buffer. Only applied when using
81
+ 'binsearch' mode. Should be a float between 0 and 1. Defaults to 0.05 (5% reduction).
82
+ max_val: Maximum batch size limit. If provided, the found batch size will not exceed this value.
83
+
78
84
"""
79
85
_check_tuner_configuration (train_dataloaders , val_dataloaders , dataloaders , method )
80
86
_check_scale_batch_size_configuration (self ._trainer )
@@ -88,6 +94,8 @@ def scale_batch_size(
88
94
init_val = init_val ,
89
95
max_trials = max_trials ,
90
96
batch_arg_name = batch_arg_name ,
97
+ margin = margin ,
98
+ max_val = max_val ,
91
99
)
92
100
# do not continue with the loop in case Tuner is used
93
101
batch_size_finder ._early_exit = True
You can’t perform that action at this time.
0 commit comments