Allow max batch size in BatchSizeFinder #20614

@rmeno12

Description & Motivation

It would be useful to be able to configure a maximum batch size for the batch size finder search (in both "power" mode and "binsearch" mode), such that when the candidate batch size exceeds that limit, the search stops and the result is capped at the configured maximum.
For me this is very useful when training transformers because of pytorch/pytorch#142228. For example, with small transformers my GPU may not OOM even past a batch size of 65535, but CUDA still raises a configuration error, so the training job crashes.

Pitch

Just as BatchSizeFinder has an init_val parameter, add a max_val parameter. It could be checked whenever the batch size is raised (I think this happens in lightning.pytorch.tuner._adjust_batch_size?). It would probably also make sense to expose it on Tuner.scale_batch_size and wherever else this logic is used. The default could be max_val=None, so the batch size is not capped unless the parameter is set; a sketch of the capping logic follows below.
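
To make the idea concrete, here is a minimal sketch of how the cap could behave. `max_val` and `_apply_max_val` are hypothetical names for this proposal, not part of Lightning's current API; the sketch only assumes the search adjusts an integer batch size step by step.

```python
from typing import Optional, Tuple

def _apply_max_val(new_size: int, max_val: Optional[int]) -> Tuple[int, bool]:
    """Cap a candidate batch size at max_val (hypothetical helper).

    Returns the possibly-capped size and a flag telling the search to stop.
    """
    if max_val is not None and new_size >= max_val:
        # Hitting the cap ends the search with the configured maximum.
        return max_val, True
    return new_size, False

# Example: a "power"-style doubling loop that respects the cap.
size, done = 32, False
while not done:
    size, done = _apply_max_val(size * 2, max_val=65535)
print(size)  # 65535
```

On the user-facing side this could look like `BatchSizeFinder(mode="binsearch", init_val=2, max_val=65535)`, again with max_val being the proposed argument, not an existing one.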

Alternatives

Alternatively, a specific check for this CUDA error could be added alongside the existing OOM checks, but matching on a particular error seems somewhat brittle; see the sketch below.
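
For comparison, a rough sketch of what that alternative might look like. The error substring is an assumption based on pytorch/pytorch#142228 ("invalid configuration argument") and may change across CUDA/PyTorch versions, which is exactly why this approach is brittle; run_trial and the fallback logic are hypothetical stand-ins for the tuner's internals.

```python
def run_trial(batch_size: int) -> None:
    # Hypothetical stand-in for one scaling trial; simulates the CUDA
    # error described in pytorch/pytorch#142228 for oversized launches.
    if batch_size > 65535:
        raise RuntimeError("CUDA error: invalid configuration argument")

def _is_cuda_launch_config_error(exc: RuntimeError) -> bool:
    # Assumed message text; not a stable API, hence the brittleness.
    return "invalid configuration argument" in str(exc)

last_good, size = 32, 65536
try:
    run_trial(size)
except RuntimeError as exc:
    if _is_cuda_launch_config_error(exc):
        size = last_good  # back off like the existing OOM handling does
    else:
        raise
print(size)  # 32
```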

Additional context

No response

cc @lantiga @Borda

Labels

feature (Is an improvement or enhancement), tuner
