
Commit e79ba85

Sean Naren authored and lexierule committed
Handle collision of user argument when using ShardedDDP (#9512)
* Handle collision of user argument
* Add CHANGELOG.md
1 parent a957d97 commit e79ba85
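For context, the collision this commit addresses is a plain Python keyword-argument clash: before the fix, `configure_ddp` always passed `reduce_buffer_size` explicitly to `ShardedDataParallel` while also forwarding `**self._ddp_kwargs`, so constructing the plugin with `reduce_buffer_size=...` raised a `TypeError`. The snippet below is a minimal, self-contained sketch of that failure mode; `sharded_data_parallel_stub` is a hypothetical stand-in for fairscale's `ShardedDataParallel`, not part of the codebase.

def sharded_data_parallel_stub(module, sharded_optimizer=None, **kwargs):
    # Hypothetical stand-in used only to demonstrate the keyword collision;
    # it simply echoes the kwargs it receives.
    return kwargs


# What DDPShardedPlugin(reduce_buffer_size=128) would store in self._ddp_kwargs.
user_ddp_kwargs = {"reduce_buffer_size": 128}

try:
    # Pre-fix call pattern: the plugin chose a value *and* forwarded the user's kwargs.
    sharded_data_parallel_stub(
        object(),
        sharded_optimizer=[],
        reduce_buffer_size=0,  # value picked by the plugin
        **user_ddp_kwargs,     # user's value collides with the explicit keyword above
    )
except TypeError as err:
    print(err)  # ... got multiple values for keyword argument 'reduce_buffer_size'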

3 files changed: 32 additions, 2 deletions

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed signature of `Timer.on_train_epoch_end` and `StochasticWeightAveraging.on_train_epoch_end` to prevent unwanted deprecation warnings ([#9347](https://github.com/PyTorchLightning/pytorch-lightning/pull/9347))


+- Fixed collision of user argument when using ShardedDDP ([#9512](https://github.com/PyTorchLightning/pytorch-lightning/pull/9512))
+
+
 ## [1.4.5] - 2021-08-31

 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))

pytorch_lightning/plugins/training_type/sharded.py

Lines changed: 5 additions & 2 deletions
@@ -36,11 +36,14 @@ class DDPShardedPlugin(DDPPlugin):

     def configure_ddp(self):
         self._wrap_optimizers()
+
+        if "reduce_buffer_size" not in self._ddp_kwargs:
+            # For multi-node training, enabling bucketing will improve performance.
+            self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
+
         self._model = ShardedDataParallel(
             LightningShardedDataParallel(self.model),
             sharded_optimizer=self.lightning_module.trainer.optimizers,
-            # For multi-node training, enabling bucketing will improve performance.
-            reduce_buffer_size=self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0,
             **self._ddp_kwargs
         )
         setattr(self._model, "require_backward_grad_sync", False)
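With this change, a user-supplied `reduce_buffer_size` is forwarded to `ShardedDataParallel` untouched, and the multi-node default is applied only when the key is absent from the plugin kwargs. A minimal usage sketch, assuming PyTorch Lightning 1.4.x with fairscale installed; the Trainer arguments here are illustrative, not taken from this commit:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPShardedPlugin

# Explicit value: kept as-is instead of colliding with the value the plugin sets.
plugin = DDPShardedPlugin(reduce_buffer_size=2 ** 23)

# If reduce_buffer_size were omitted, the plugin would fall back to
# _REDUCE_BUFFER_SIZE_DEFAULT when num_nodes > 1 and to 0 (bucketing disabled) otherwise.
trainer = Trainer(gpus=2, plugins=plugin)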

tests/plugins/test_sharded_plugin.py

Lines changed: 24 additions & 0 deletions
@@ -268,3 +268,27 @@ def test_custom_kwargs_sharded(tmpdir, cls):
     args, kwargs = mock_sharded.call_args
     assert "reduce_fp16" in kwargs
     assert kwargs["reduce_fp16"]
+
+
+@RunIf(skip_windows=True, fairscale=True)
+@mock.patch("pytorch_lightning.plugins.DDPShardedPlugin._wrap_optimizers", autospec=True)
+@pytest.mark.parametrize(["params", "expected_buffer_size"], [(dict(), 0), (dict(reduce_buffer_size=128), 128)])
+@pytest.mark.parametrize("num_nodes", [1, 2])
+def test_custom_kwargs_sharded_reduce_buffer_size(tmpdir, params, expected_buffer_size, num_nodes):
+    """Tests to ensure that ``reduce_buffer_size`` is correctly set based on user kwargs."""
+    plugin = DDPShardedPlugin(**params)
+    plugin.num_nodes = num_nodes
+
+    with mock.patch.object(plugin, "_model", autospec=True):
+        with mock.patch(
+            "pytorch_lightning.plugins.training_type.sharded.ShardedDataParallel", autospec=True
+        ) as mock_sharded:
+            plugin.configure_ddp()
+    args, kwargs = mock_sharded.call_args
+    assert "reduce_buffer_size" in kwargs
+
+    if num_nodes > 1 and len(params) == 0:
+        # If user has not specified a buffer size and we're using multiple nodes, check to see if default is set
+        assert kwargs["reduce_buffer_size"] == DDPShardedPlugin._REDUCE_BUFFER_SIZE_DEFAULT
+    else:
+        assert kwargs["reduce_buffer_size"] == expected_buffer_size
