Description & Motivation
It would be good to be able to configure a maximum batch size for the batch size finder search (both "power" mode and "binsearch" mode), such that when the search reaches that limit it stops and the result is capped at the configured maximum.
For me this is very useful when training my transformers because of pytorch/pytorch#142228: with small transformers my GPU may not OOM even above a batch size of 65535, but CUDA still raises a misconfiguration error and the training job crashes.
Pitch
Just as there is an `init_val` parameter in the `BatchSizeFinder`, also add a `max_val` parameter. This could then be checked against when raising the batch size (I think this happens in `lightning.pytorch.tuner._adjust_batch_size`?). It would probably be appropriate to add this to `Tuner.scale_batch_size` and wherever else this is used. The default behavior could be to leave `max_val` at `None` and not cap the batch size if the parameter is not set.
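As a rough illustration, here is a minimal sketch of how the cap could slot into a "power"-style search loop. Apart from the `init_val`/`max_val` names from this pitch, everything here (`scale_batch_size_power`, `try_batch_size`, the trial budget) is hypothetical and not Lightning's actual implementation:

```python
# Minimal sketch of a max_val cap in a power-scaling search.
# Hypothetical names; not the real Lightning internals.
from typing import Callable, Optional


def scale_batch_size_power(
    try_batch_size: Callable[[int], None],
    init_val: int = 2,
    max_trials: int = 25,
    max_val: Optional[int] = None,
) -> int:
    """Double the batch size until OOM, the trial budget, or the optional cap."""
    new_size = init_val
    for _ in range(max_trials):
        candidate = new_size * 2
        # Proposed behavior: stop the search and return the configured maximum.
        if max_val is not None and candidate > max_val:
            return max_val
        try:
            try_batch_size(candidate)
        except RuntimeError as err:
            if "out of memory" in str(err).lower():
                return new_size  # last size that fit
            raise
        new_size = candidate
    return new_size


if __name__ == "__main__":
    def fake_trial(batch_size: int) -> None:
        # Pretend the GPU fits anything up to 2**20 samples.
        if batch_size > 2**20:
            raise RuntimeError("CUDA out of memory")

    # Without a cap the search would sail past 65535; with max_val it stops there.
    print(scale_batch_size_power(fake_trial, max_val=65535))  # -> 65535
```

With `max_val=None` the loop behaves exactly as before, so the cap is purely opt-in.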
Alternatives
You could also add a specific check for this error in addition to the OOM checks, but that seems somewhat brittle.
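For reference, a sketch of what such a check might look like; the exact message string is an assumption based on pytorch/pytorch#142228, which is part of why this approach feels brittle:

```python
# Sketch only: treat the CUDA grid-limit error like an OOM during the search.
# The "invalid configuration argument" string is an assumption about the error text.
def is_batch_size_related_error(err: RuntimeError) -> bool:
    message = str(err).lower()
    return (
        "out of memory" in message  # the existing OOM case
        or "invalid configuration argument" in message  # assumed CUDA grid-limit error
    )
```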
Additional context
No response