
Commit a707438

Sean Naren authored and lexierule committed
[DeepSpeed] Do not fail if batch size could not be inferred for logging (#10438)
(cherry picked from commit e98ace3)
1 parent ae6da92 commit a707438

File tree: 3 files changed (+23 / −27 lines)

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -17,6 +17,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed sampler replacement logic with `overfit_batches` to only replace the sample when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486))
 - Fixed scripting causing false positive deprecation warnings ([#10470](https://github.com/PyTorchLightning/pytorch-lightning/pull/10470), [#10555](https://github.com/PyTorchLightning/pytorch-lightning/pull/10555))
 
+### Changed
+
+- Do not fail if batch size could not be inferred for logging when using DeepSpeed ([#10438](https://github.com/PyTorchLightning/pytorch-lightning/issues/10438))
+
 
 ## [1.5.1] - 2021-11-09
 
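Note: the warning introduced in this commit (see the deepspeed.py hunk below) points users at the existing `logging_batch_size_per_gpu` argument. A minimal usage sketch, assuming a single GPU and a batch size of 32 (both illustrative values, not part of this commit):

```python
import pytorch_lightning as pl
from pytorch_lightning.plugins import DeepSpeedPlugin

# If the train dataloader cannot be inspected before `setup()` runs, tell
# DeepSpeed explicitly which per-GPU batch size to use for its logging and
# throughput statistics instead of letting the plugin try to infer it.
trainer = pl.Trainer(
    gpus=1,
    strategy=DeepSpeedPlugin(logging_batch_size_per_gpu=32),
)
```

Passing the value explicitly sidesteps the inference entirely, so DeepSpeed's logging stays correct even when the dataloader cannot be built early.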

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 13 additions & 8 deletions
@@ -622,11 +622,6 @@ def _format_batch_size_and_grad_accum_config(self):
             )
             self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
         if "train_micro_batch_size_per_gpu" not in self.config:
-            rank_zero_warn(
-                "Inferring the batch size for internal deepspeed logging from the `train_dataloader()`. "
-                "If you require skipping this, please pass "
-                "`Trainer(strategy=DeepSpeedPlugin(logging_batch_size_per_gpu=batch_size))`"
-            )
             batch_size = self._auto_select_batch_size()
             self.config["train_micro_batch_size_per_gpu"] = batch_size
         if "gradient_clipping" not in self.config:
@@ -638,9 +633,19 @@ def _auto_select_batch_size(self):
         batch_size = 1
         train_dl_source = self.lightning_module.trainer._data_connector._train_dataloader_source
         if train_dl_source.is_defined():
-            train_dataloader = train_dl_source.dataloader()
-            if hasattr(train_dataloader, "batch_sampler"):
-                batch_size = train_dataloader.batch_sampler.batch_size
+            try:
+                train_dataloader = train_dl_source.dataloader()
+                if hasattr(train_dataloader, "batch_sampler"):
+                    batch_size = train_dataloader.batch_sampler.batch_size
+            # broad exception on purpose as `source.dataloader()` will fail if the dataloader requires `setup`
+            # to have been called before
+            except Exception:
+                if self.global_rank == 0:
+                    deepspeed.utils.logging.logger.warning(
+                        "Tried to infer the batch size for internal deepspeed logging from the `train_dataloader()`. "
+                        "To ensure DeepSpeed logging remains correct, please manually pass the plugin with the "
+                        "batch size, `Trainer(strategy=DeepSpeedPlugin(logging_batch_size_per_gpu=batch_size))`."
+                    )
         return batch_size
 
     def _format_precision_config(self):
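For readers skimming the hunk above: the new behaviour is best-effort inference with a rank-zero warning instead of a hard failure. A minimal standalone sketch of that pattern, using a plain `logging` logger as a stand-in for `deepspeed.utils.logging.logger` and hypothetical names (`infer_logging_batch_size`, `dataloader_factory`) that are not part of the plugin's API:

```python
import logging

# stand-in for deepspeed.utils.logging.logger
logger = logging.getLogger("deepspeed_batch_size_fallback")


def infer_logging_batch_size(dataloader_factory, global_rank: int = 0) -> int:
    """Best-effort batch size inference, used only for DeepSpeed's logging stats."""
    batch_size = 1  # safe default when nothing can be inferred
    try:
        # Building the dataloader may fail, e.g. if it needs `setup()` to have run first.
        train_dataloader = dataloader_factory()
        if hasattr(train_dataloader, "batch_sampler"):
            batch_size = train_dataloader.batch_sampler.batch_size
    except Exception:
        # Do not abort training over a logging detail; warn on rank zero only.
        if global_rank == 0:
            logger.warning(
                "Could not infer the batch size from the train dataloader; "
                "pass `DeepSpeedPlugin(logging_batch_size_per_gpu=...)` to set it explicitly."
            )
    return batch_size
```

The broad `except Exception` mirrors the change above: building the dataloader can fail for reasons the plugin cannot anticipate (most commonly because `setup()` has not run yet), and none of them should fail training over a logging statistic.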

tests/plugins/test_deepspeed_plugin.py

Lines changed: 6 additions & 19 deletions
@@ -1,5 +1,6 @@
 import contextlib
 import json
+import logging
 import os
 from typing import Any, Dict, Optional
 from unittest import mock
@@ -872,24 +873,9 @@ def training_step(self, batch, batch_idx):
     trainer.fit(model)
 
 
-@RunIf(min_gpus=1, deepspeed=True, special=True)
-def test_deepspeed_warn_train_dataloader_called(tmpdir):
-    """Test DeepSpeed warns when it calls ``lightning_module.train_dataloader`` internally for logging batch
-    size."""
-    model = BoringModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        strategy=DeepSpeedPlugin(),
-        gpus=1,
-        fast_dev_run=True,
-    )
-    with pytest.warns(UserWarning, match="Inferring the batch size for internal deepspeed logging"):
-        trainer.fit(model)
-
-
 @RunIf(min_gpus=1, deepspeed=True, special=True)
 def test_deepspeed_setup_train_dataloader(tmpdir):
-    """Test DeepSpeed works when setup is required to call, and the user passes the batch size manually."""
+    """Test DeepSpeed works when setup is required to call in the DataModule."""
 
     class TestSetupIsCalledDataModule(LightningDataModule):
         def __init__(self):
@@ -914,13 +900,14 @@ def test_dataloader(self):
     model = BoringModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
-        strategy=DeepSpeedPlugin(logging_batch_size_per_gpu=32),
+        strategy=DeepSpeedPlugin(logging_level=logging.INFO),
         gpus=1,
         fast_dev_run=True,
     )
     dm = TestSetupIsCalledDataModule()
-    trainer.fit(model, datamodule=dm)
-    trainer.test(model, datamodule=dm)
+    with mock.patch("deepspeed.utils.logging.logger.warning", autospec=True) as mock_object:
+        trainer.fit(model, datamodule=dm)
+    assert any("Tried to infer the batch size" in str(arg) for arg in mock_object.call_args_list)
 
 
 @mock.patch("torch.optim.lr_scheduler.StepLR.step", autospec=True)
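The old test asserted with `pytest.warns` because the message was raised through `rank_zero_warn`; the new message goes through DeepSpeed's `logging` logger, which `pytest.warns` cannot intercept, so the updated test patches the logger and inspects its calls. A minimal illustration of that assertion pattern, with a hypothetical `emit_warning` function standing in for the plugin code:

```python
import logging
from unittest import mock

logger = logging.getLogger("example")


def emit_warning():
    # hypothetical code under test that warns through a logger, not `warnings.warn`
    logger.warning("Tried to infer the batch size for logging.")


def test_emit_warning_message():
    # patch the logger method and check the recorded calls afterwards
    with mock.patch.object(logger, "warning", autospec=True) as mock_warning:
        emit_warning()
    assert any("Tried to infer the batch size" in str(args) for args in mock_warning.call_args_list)
```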
