Description & Motivation
My model allocates a bit more space for some examples than others. When I run BatchSizeFinder, the batch size it discovers works for a while, but eventually training hits a larger example and I run out of memory.
Pitch
I would like to add an argument to BatchSizeFinder that decreases the final batch size by a multiplicative factor. Possibilities are:
- `safety_margin` of 0.1 would mean multiplying the final batch size by (1 - 0.1)
- `scale` of 0.9 would mean multiplying the final batch size by 0.9
This would be turned off by default.
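To make the proposal concrete, here is a minimal sketch of the arithmetic, independent of Lightning. `apply_safety_margin` is a hypothetical helper name; in the actual change this logic would live inside `BatchSizeFinder`:

```python
def apply_safety_margin(optimal_batch_size: int, safety_margin: float = 0.0) -> int:
    """Reduce a discovered batch size by a multiplicative safety margin.

    A margin of 0.0 (the proposed default) leaves the batch size unchanged;
    the result is clamped to at least 1 so training can still proceed.
    """
    assert 0.0 <= safety_margin < 1.0
    return max(1, int(optimal_batch_size * (1.0 - safety_margin)))
```

For example, with a discovered batch size of 128 and `safety_margin=0.1`, the tuner would report 115 instead.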
Alternatives
I implemented a pretty hacky version of this using inheritance:
```python
# Adds a safety margin. For example, `safety_margin` of 0.1 indicates that
# the final batch_size will be reduced by 10%
class SafeBatchSizeFinder(BatchSizeFinder):
    def __init__(self, safety_margin=0.1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert 0.0 <= safety_margin <= 1.0
        self.safety_margin = safety_margin

    def scale_batch_size(self, trainer, *args, **kwargs):
        super().scale_batch_size(trainer, *args, **kwargs)
        original_batch_size = self.optimal_batch_size
        new_batch_size = int(self.optimal_batch_size * (1.0 - self.safety_margin))
        print(
            f"Found optimal batch size of {original_batch_size}, but with a safety "
            f"margin of {self.safety_margin}, reducing it to {new_batch_size}"
        )
        self.optimal_batch_size = new_batch_size
        # This adjusts the data module batch_size.
        pl.tuner.batch_size_scaling._adjust_batch_size(trainer, value=new_batch_size)
        pl.tuner.batch_size_scaling._reset_dataloaders(trainer)
        trainer._active_loop.reset()
```
Additional context
I am willing to implement this. I don't think it would be hard.