
Commit 6bd71be

rohitgr7 authored and lexierule committed
Reset dataloaders on failure in tuner (#14372)
1 parent 582b8cc commit 6bd71be
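
For context, the code paths changed below sit behind Lightning's batch-size finder, which is driven through Trainer.tune. A minimal usage sketch, assuming a LightningModule that exposes a batch_size attribute used by its train_dataloader; the helper name and the kwarg values here are illustrative, not part of this commit:

import pytorch_lightning as pl


def find_batch_size(model: "pl.LightningModule") -> int:
    # Enable the batch-size finder; Trainer.tune then runs only the tuning phase.
    trainer = pl.Trainer(default_root_dir=".", max_epochs=2, auto_scale_batch_size=True)
    # "power" doubles the batch size until a failure; "binsearch" additionally
    # refines the result with a binary search between the last good and bad sizes.
    result = trainer.tune(model, scale_batch_size_kwargs={"mode": "power", "max_trials": 10})
    return result["scale_batch_size"]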

File tree: 3 files changed (49 additions, 13 deletions)


src/pytorch_lightning/CHANGELOG.md

Lines changed: 0 additions & 1 deletion
@@ -8,7 +8,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
-- When using multiple loggers, by default checkpoints and profiler output now get saved to the log dir of the first logger in the list ([#14325](https://github.com/Lightning-AI/lightning/pull/14325))
 - Improved the error messaging when passing `Trainer.method(model, x_dataloader=None)` with no module-method implementations available ([#14614](https://github.com/Lightning-AI/lightning/pull/14614))
 
 ### Fixed

src/pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 26 additions & 12 deletions
@@ -126,6 +126,9 @@ def _run_power_scaling(
     trainer: "pl.Trainer", model: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int
 ) -> int:
     """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
+    # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
+    # if it was we exit, else we continue downscaling in case we haven't encountered a single optimal batch size
+    any_success = False
     for _ in range(max_trials):
         garbage_collection_cuda()

@@ -137,22 +140,28 @@ def _run_power_scaling(
             trainer.tuner._run(model)
             # Double in size
             new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")
+
+            if not changed:
+                break
+
+            # Force the train dataloader to reset as the batch size has changed
+            trainer.reset_train_dataloader(model)
+            trainer.reset_val_dataloader(model)
+            any_success = True
         except RuntimeError as exception:
             # Only these errors should trigger an adjustment
             if is_oom_error(exception):
                 # If we fail in power mode, half the size and return
                 garbage_collection_cuda()
                 new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc="failed")
-                break
+                # Force the train dataloader to reset as the batch size has changed
+                trainer.reset_train_dataloader(model)
+                trainer.reset_val_dataloader(model)
+                if any_success:
+                    break
             else:
                 raise  # some other error not memory related

-        if changed:
-            # Force the train dataloader to reset as the batch size has changed
-            trainer.reset_train_dataloader(model)
-            trainer.reset_val_dataloader(model)
-        else:
-            break
     return new_size

@@ -189,13 +198,13 @@ def _run_binsearch_scaling(
             else:
                 new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")

-            if changed:
-                # Force the train dataloader to reset as the batch size has changed
-                trainer.reset_train_dataloader(model)
-                trainer.reset_val_dataloader(model)
-            else:
+            if not changed:
                 break

+            # Force the train dataloader to reset as the batch size has changed
+            trainer.reset_train_dataloader(model)
+            trainer.reset_val_dataloader(model)
+
         except RuntimeError as exception:
             # Only these errors should trigger an adjustment
             if is_oom_error(exception):
@@ -204,6 +213,11 @@ def _run_binsearch_scaling(
                 high = new_size
                 midval = (high + low) // 2
                 new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="failed")
+
+                # Force the train dataloader to reset as the batch size has changed
+                trainer.reset_train_dataloader(model)
+                trainer.reset_val_dataloader(model)
+
                 if high - low <= 1:
                     break
             else:
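
To make the new control flow in _run_power_scaling easier to follow, here is a small self-contained sketch of the doubling/halving loop with the any_success flag. fits_in_memory stands in for running a few training steps and catching an OOM error; this is an illustration under that assumption, not the Lightning implementation:

def fits_in_memory(batch_size: int, limit: int = 100) -> bool:
    # stand-in for trainer.tuner._run(model) raising an out-of-memory error
    return batch_size <= limit


def power_scale(init_size: int, max_trials: int = 10, limit: int = 100) -> int:
    new_size = init_size
    any_success = False
    for _ in range(max_trials):
        if fits_in_memory(new_size, limit):
            new_size *= 2    # "succeeded": double and (in Lightning) reset the dataloaders
            any_success = True
        else:
            new_size //= 2   # "failed": halve and (in Lightning) reset the dataloaders
            if any_success:
                break        # a previous size already worked, so stop here
            # otherwise keep halving until some size fits
    return new_size


print(power_scale(500))  # -> 62, the "power" expectation in the test added below

With the failure threshold used by the new test (sizes above 100 fail) and init_val=500, the loop settles on 62: it shrinks 500 -> 250 -> 125 -> 62, succeeds, doubles to 124, fails once more, and stops at 62.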

tests/tests_pytorch/tuner/test_scale_batch_size.py

Lines changed: 23 additions & 0 deletions
@@ -319,3 +319,26 @@ def test_dataloader_reset_with_scale_batch_size(tmpdir, scale_method):
 
     assert trainer.train_dataloader.loaders.batch_size == new_batch_size
     assert trainer.val_dataloaders[0].batch_size == new_batch_size
+
+
+@pytest.mark.parametrize("scale_method, expected_batch_size", [("power", 62), ("binsearch", 100)])
+@patch("pytorch_lightning.tuner.batch_size_scaling.is_oom_error", return_value=True)
+def test_dataloader_batch_size_updated_on_failure(_, tmpdir, scale_method, expected_batch_size):
+    class CustomBatchSizeModel(BatchSizeModel):
+        def training_step(self, *_, **__):
+            if self.batch_size > 100:
+                raise RuntimeError
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(32, 1000), batch_size=self.batch_size)
+
+    model = CustomBatchSizeModel(batch_size=16)
+    model.validation_step = None
+    model.training_epoch_end = None
+    scale_batch_size_kwargs = {"max_trials": 10, "steps_per_trial": 1, "init_val": 500, "mode": scale_method}
+
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, auto_scale_batch_size=True)
+    new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
+    assert new_batch_size == model.batch_size
+    assert new_batch_size == expected_batch_size
+    assert trainer.train_dataloader.loaders.batch_size == expected_batch_size
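
The binsearch expectation of 100 in the parametrization above can be reproduced with a similar self-contained sketch of _run_binsearch_scaling's search. Again, fits is a stand-in for running steps_per_trial batches and catching an OOM error, and this is an illustration rather than the Lightning code; in the real implementation the train/val dataloaders are now reset after every one of these size changes, including the failure path:

def fits(batch_size: int, limit: int = 100) -> bool:
    # stand-in for running steps_per_trial batches and catching an OOM error
    return batch_size <= limit


def binsearch_scale(init_size: int, max_trials: int = 10, limit: int = 100) -> int:
    low, high = 1, None
    new_size = init_size
    count = 0
    while True:
        if fits(new_size, limit):
            count += 1
            if count > max_trials:
                break
            low = new_size
            if high:
                if high - low <= 1:
                    break
                new_size = (high + low) // 2  # refine inside the known-good window
            else:
                new_size *= 2                 # no failure seen yet, keep doubling
        else:
            high = new_size
            new_size = (high + low) // 2      # back off to the midpoint after a failure
            if high - low <= 1:
                break
        # Lightning resets the train/val dataloaders after each of these size changes
    return new_size


print(binsearch_scale(500))  # -> 100, matching the "binsearch" case above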
