Skip to content

Commit 7049f09

Browse files
Sean Naren authored and lexierule committed
Introduce parameter to fix deepspeed crash for RNNS (#9489)
1 parent 845a179 commit 7049f09

File tree

3 files changed

+63
-5
lines changed

3 files changed

+63
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1111
- Fixed `replace_sampler` missing the batch size under specific conditions ([#9367](https://github.com/PyTorchLightning/pytorch-lightning/pull/9367))
1212
- Pass init args to ShardedDataParallel ([#9483](https://github.com/PyTorchLightning/pytorch-lightning/pull/9483))
1313
- Fixed collision of user argument when using ShardedDDP ([#9512](https://github.com/PyTorchLightning/pytorch-lightning/pull/9512))
14+
- Fixed DeepSpeed crash for RNNs ([#9489](https://github.com/PyTorchLightning/pytorch-lightning/pull/9489))
1415

1516

1617
## [1.4.6] - 2021-09-07

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def __init__(
123123
cpu_offload: bool = False,
124124
cpu_offload_params: bool = False,
125125
cpu_offload_use_pin_memory: bool = False,
126+
partition_module: bool = True,
126127
) -> None:
127128
"""
128129
Provides capabilities to run training using the DeepSpeed library,
@@ -254,6 +255,12 @@ def __init__(
254255
when using ZeRO Stage 3. This allows a single weight file to contain the entire model,
255256
rather than individual sharded weight files.
256257
Disable to save sharded states individually.
258+
259+
partition_module: When True, partitions the ``LightningModule`` across devices when using ZeRO Stage 3.
260+
This is the default behaviour to ensure that the entire module is appropriately initialized
261+
for DeepSpeed. When False we do not explicitly convert the model, which is fine if NO layers
262+
or ALL layers are defined in ``configure_sharded_model``. This is useful for layers such as
263+
``torch.nn.RNN`` which do internal logic when moving to device.
257264
"""
258265
if not _DEEPSPEED_AVAILABLE:
259266
raise MisconfigurationException(
@@ -314,6 +321,7 @@ def __init__(
314321

315322
self.remote_device = remote_device
316323
self.save_full_weights = save_full_weights
324+
self.partition_module = partition_module
317325

318326
# default FP16 parameters.
319327
self.loss_scale = loss_scale
@@ -375,7 +383,7 @@ def init_deepspeed(self):
375383
precision = self.lightning_module.trainer.accelerator.precision
376384
model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)
377385

378-
if self.zero_stage_3:
386+
if self.zero_stage_3 and self.partition_module:
379387
# Ensure the entire model has been moved to the appropriate device
380388
dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32
381389
deepspeed.zero.Init(

tests/plugins/test_deepspeed_plugin.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -643,17 +643,66 @@ def on_train_batch_start(
643643

644644

645645
@RunIf(min_gpus=2, deepspeed=True, special=True)
646-
def test_deepspeed_multigpu_test(tmpdir, deepspeed_config):
647-
"""
648-
Test to ensure we can use DeepSpeed with just test using ZeRO Stage 3.
649-
"""
646+
def test_deepspeed_multigpu_test(tmpdir):
647+
"""Test to ensure we can use DeepSpeed with just test using ZeRO Stage 3."""
650648
model = ModelParallelBoringModel()
651649
trainer = Trainer(
652650
default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16
653651
)
654652
trainer.test(model)
655653

656654

655+
@RunIf(min_gpus=1, deepspeed=True, special=True)
656+
def test_deepspeed_multigpu_partial_partition_parameters(tmpdir):
657+
"""Test to ensure that a module that defines a layer inside the ``__init__`` and ``configure_sharded_model``
658+
correctly converts all parameters to float16 when ``precision=16`` and runs successfully."""
659+
660+
class TestModel(ModelParallelBoringModel):
661+
def __init__(self):
662+
super().__init__()
663+
self.layer_2 = torch.nn.Linear(32, 32)
664+
665+
def configure_sharded_model(self) -> None:
666+
self.layer = torch.nn.Linear(32, 2)
667+
668+
def forward(self, x):
669+
x = self.layer_2(x)
670+
return self.layer(x)
671+
672+
def on_train_epoch_start(self) -> None:
673+
assert all([x.dtype == torch.float16 for x in self.parameters()])
674+
675+
model = TestModel()
676+
trainer = Trainer(
677+
default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=1, fast_dev_run=True, precision=16
678+
)
679+
trainer.fit(model)
680+
681+
682+
@RunIf(min_gpus=1, deepspeed=True, special=True)
683+
def test_deepspeed_multigpu_test_rnn(tmpdir):
684+
"""Test to ensure that turning off explicit partitioning of the entire module for ZeRO Stage 3 works when
685+
training with certain layers which will crash with explicit partitioning."""
686+
687+
class TestModel(BoringModel):
688+
def __init__(self):
689+
super().__init__()
690+
self.rnn = torch.nn.GRU(32, 32)
691+
692+
def on_train_epoch_start(self) -> None:
693+
assert all([x.dtype == torch.float16 for x in self.parameters()])
694+
695+
model = TestModel()
696+
trainer = Trainer(
697+
default_root_dir=tmpdir,
698+
plugins=[DeepSpeedPlugin(stage=3, partition_module=False)],
699+
gpus=1,
700+
fast_dev_run=True,
701+
precision=16,
702+
)
703+
trainer.fit(model)
704+
705+
657706
@RunIf(deepspeed=True)
658707
@mock.patch("deepspeed.init_distributed", autospec=True)
659708
@pytest.mark.parametrize("platform", ["Linux", "Windows"])

0 commit comments

Comments (0)