
Commit 0bd69c9

rohitgr7 authored and ananthsub committed
Fix to avoid moving batch to device for DataParallel (#11780)
Co-authored-by: ananthsub <[email protected]>
1 parent e17dd29 commit 0bd69c9
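For context: `torch.nn.DataParallel` scatters the input batch across its `device_ids` inside `forward`, so moving the whole batch to a device beforehand is redundant work. A minimal sketch of that behavior, separate from this commit and assuming a machine with at least two CUDA devices:

```python
import torch
import torch.nn as nn

# The wrapped module lives on the first device; DataParallel replicates it.
model = nn.DataParallel(nn.Linear(32, 4).cuda(0), device_ids=[0, 1])

batch = torch.randn(8, 32)  # the batch can stay on the CPU
assert not batch.is_cuda

# forward() scatters chunks of the batch onto each GPU, runs the replicas
# in parallel, and gathers the outputs back onto the output device.
out = model(batch)
assert out.is_cuda
```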

File tree

3 files changed: +30 -2 lines changed


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The Rich progress bar now correctly shows the `on_epoch` logged values on train epoch end ([#11689](https://github.com/PyTorchLightning/pytorch-lightning/pull/11689))
 - Fixed an issue to make the `step` argument in `WandbLogger.log_image` work ([#11716](https://github.com/PyTorchLightning/pytorch-lightning/pull/11716))
 - Fixed `restore_optimizers` for mapping states ([#11757](https://github.com/PyTorchLightning/pytorch-lightning/pull/11757))
+- With `DPStrategy`, the batch is not explicitly moved to the device ([#11780](https://github.com/PyTorchLightning/pytorch-lightning/pull/11780))
+
 
 
 ## [1.5.9] - 2022-01-18

pytorch_lightning/accelerators/accelerator.py

Lines changed: 5 additions & 2 deletions
@@ -201,8 +201,11 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dat
         model = self.lightning_module
         device = device or self.root_device
 
-        if model is not None and not isinstance(self.training_type_plugin, DataParallelPlugin):
-            # no need to transfer batch to device in DP mode
+        # no need to transfer batch to device in DP mode
+        if isinstance(self.training_type_plugin, DataParallelPlugin):
+            return batch
+
+        if model is not None:
             return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)
 
         return move_data_to_device(batch, device)
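Before this change, DP batches failed the combined `model is not None and not isinstance(...)` check and fell through to `move_data_to_device(batch, device)`, which moved them to the root device anyway; the early return now short-circuits both paths. A hypothetical standalone rendering of the fixed decision order (the real logic is a method on `Accelerator` and consults the training type plugin):

```python
import torch

def batch_to_device(batch: torch.Tensor, device: torch.device, is_dp: bool) -> torch.Tensor:
    """Sketch only: decide whether the batch needs an explicit transfer."""
    if is_dp:
        # DP mode: nn.DataParallel scatters the batch itself during forward().
        return batch
    # Every other case: move the batch explicitly.
    return batch.to(device)

cpu_batch = torch.randn(4, 32)
# Under DP the batch comes back untouched, still on the CPU.
assert not batch_to_device(cpu_batch, torch.device("cuda:0"), is_dp=True).is_cuda
```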

tests/accelerators/test_dp.py

Lines changed: 23 additions & 0 deletions
@@ -196,3 +196,26 @@ def test_dp_training_step_dict(tmpdir):
         strategy="dp",
     )
     trainer.fit(model)
+    trainer.test(model)
+
+
+@RunIf(min_gpus=2)
+def test_dp_batch_not_moved_to_device_explicitly(tmpdir):
+    """Test that with DP, the batch is not explicitly moved to the device."""
+
+    class CustomModel(BoringModel):
+        def on_train_batch_start(self, batch, *args, **kwargs):
+            assert not batch.is_cuda
+
+        def training_step(self, batch, batch_idx):
+            assert batch.is_cuda
+            return super().training_step(batch, batch_idx)
+
+    trainer = pl.Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        gpus=2,
+        strategy="dp",
+    )
+
+    trainer.fit(CustomModel())
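The two assertions differ on purpose: `on_train_batch_start` fires before the model's forward pass, so under DP it sees the untouched CPU batch, while `training_step` runs inside the `DataParallel` replicas after the batch has been scattered, so each replica sees a CUDA chunk.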
