
Commit a957d97

Sean Naren authored and lexierule committed
Pass args to ShardedDataParallel (#9483)
1 parent e17af8f commit a957d97
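
In user terms, extra keyword arguments given to the sharded training-type plugins are now forwarded to fairscale's ShardedDataParallel instead of being dropped. A rough usage sketch (not part of the commit; assumes fairscale is installed and a multi-GPU setup, and mirrors the reduce_fp16=True kwarg exercised by the new test below):

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPShardedPlugin

# Extra kwargs on the plugin are stored in _ddp_kwargs and forwarded to
# fairscale's ShardedDataParallel inside configure_ddp (see the diffs below).
plugin = DDPShardedPlugin(reduce_fp16=True)
trainer = Trainer(gpus=2, plugins=[plugin])
# trainer.fit(model)  # model: any LightningModule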

File tree

4 files changed, +25 -1 lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed logging of nan parameters ([#9364](https://github.com/PyTorchLightning/pytorch-lightning/pull/9364))
 - Fixed `replace_sampler` missing the batch size under specific conditions ([#9367](https://github.com/PyTorchLightning/pytorch-lightning/pull/9367))
+- Pass init args to ShardedDataParallel ([#9483](https://github.com/PyTorchLightning/pytorch-lightning/pull/9483))
 
 
 ## [1.4.6] - 2021-09-07
@@ -30,6 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed inspection of other args when a container is specified in `save_hyperparameters` ([#9125](https://github.com/PyTorchLightning/pytorch-lightning/pull/9125))
 - Fixed signature of `Timer.on_train_epoch_end` and `StochasticWeightAveraging.on_train_epoch_end` to prevent unwanted deprecation warnings ([#9347](https://github.com/PyTorchLightning/pytorch-lightning/pull/9347))
 
+
 ## [1.4.5] - 2021-08-31
 
 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))

pytorch_lightning/plugins/training_type/sharded.py

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ def configure_ddp(self):
             sharded_optimizer=self.lightning_module.trainer.optimizers,
             # For multi-node training, enabling bucketing will improve performance.
             reduce_buffer_size=self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0,
+            **self._ddp_kwargs
         )
         setattr(self._model, "require_backward_grad_sync", False)
 
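For context (not part of the diff): self._ddp_kwargs is the dict of extra keyword arguments captured by the plugin's constructor, inherited from the DDP plugin base classes. A minimal, hypothetical sketch of the forwarding pattern this hunk implements, assuming fairscale is installed; the class and parameter names here are illustrative, not Lightning's:

from fairscale.nn.data_parallel import ShardedDataParallel


class KwargsForwardingSketch:
    """Illustrative stand-in for the sharded plugins, not the real implementation."""

    def __init__(self, **kwargs):
        # e.g. reduce_fp16=True, as in the new test further down
        self._ddp_kwargs = kwargs

    def configure_ddp(self, wrapped_module, optimizers):
        # Stored kwargs are passed through verbatim to fairscale's
        # ShardedDataParallel, which is what this commit adds.
        self._model = ShardedDataParallel(
            wrapped_module,
            sharded_optimizer=optimizers,
            **self._ddp_kwargs,
        )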

pytorch_lightning/plugins/training_type/sharded_spawn.py

Lines changed: 3 additions & 1 deletion
@@ -36,7 +36,9 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     def configure_ddp(self):
         self._wrap_optimizers()
         self._model = ShardedDataParallel(
-            LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
+            LightningShardedDataParallel(self.model),
+            sharded_optimizer=self.lightning_module.trainer.optimizers,
+            **self._ddp_kwargs
         )
         setattr(self._model, "require_backward_grad_sync", False)
 

tests/plugins/test_sharded_plugin.py

Lines changed: 19 additions & 0 deletions
@@ -249,3 +249,22 @@ def test_ddp_sharded_plugin_manual_optimization(tmpdir):
     model = ManualBoringModel()
     trainer = Trainer(default_root_dir=tmpdir, accelerator="ddp_sharded", fast_dev_run=2, gpus=2)
     trainer.fit(model)
+
+
+@RunIf(skip_windows=True, fairscale=True)
+@mock.patch("pytorch_lightning.plugins.DDPShardedPlugin._wrap_optimizers", autospec=True)
+@pytest.mark.parametrize("cls", [DDPShardedPlugin, DDPSpawnShardedPlugin])
+def test_custom_kwargs_sharded(tmpdir, cls):
+    """Tests to ensure that if custom kwargs are passed, they are set correctly."""
+    plugin = cls(reduce_fp16=True)
+
+    class_name = "sharded" if isinstance(plugin, DDPShardedPlugin) else "sharded_spawn"
+
+    with mock.patch.object(plugin, "_model", autospec=True):
+        with mock.patch(
+            f"pytorch_lightning.plugins.training_type.{class_name}.ShardedDataParallel", autospec=True
+        ) as mock_sharded:
+            plugin.configure_ddp()
+    args, kwargs = mock_sharded.call_args
+    assert "reduce_fp16" in kwargs
+    assert kwargs["reduce_fp16"]
