Commit 6da5829

DeepSpeed support for device IDs (#9847)
Author: Sean Naren
1 parent f16bfe9 commit 6da5829
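
Summary (as reflected in the diff below): the change passes the index of the Trainer's selected root device to `deepspeed.initialize` via an `argparse.Namespace` carrying a `device_rank` attribute, in both the training and inference setup paths, so DeepSpeed binds to the GPU the user picked (e.g. `gpus=[1]`) instead of defaulting to the local rank's device. A multi-GPU regression test asserts that the model and batches land on GPU index 1 during both `fit` and `test`.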

File tree: 3 files changed, +43 −0 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -489,6 +489,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed missing arguments when saving hyperparameters from the parent class but not from the child class ([#9800](https://github.com/PyTorchLightning/pytorch-lightning/pull/9800))


+- Fixed DeepSpeed GPU device IDs ([#9847](https://github.com/PyTorchLightning/pytorch-lightning/pull/9847))
+
+
 - Reset `val_dataloader` in `tuner/batch_size_scaling` ([#9857](https://github.com/PyTorchLightning/pytorch-lightning/pull/9857))


pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 3 additions & 0 deletions

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import argparse
 import contextlib
 import json
 import logging
@@ -429,6 +430,7 @@ def _initialize_deepspeed_train(self, model):

         model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
         model, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
+            args=argparse.Namespace(device_rank=self.root_device.index),
             config=self.config,
             model=model,
             model_parameters=model_parameters,
@@ -505,6 +507,7 @@ def _initialize_deepspeed_inference(self, model):
         # Remove all module hooks before initializing new model
         remove_module_hooks(model)
         model, _, _, _ = deepspeed.initialize(
+            args=argparse.Namespace(device_rank=self.root_device.index),
             config=inference_config,
             model=model,
             optimizer=optimizer,
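
For illustration, a minimal sketch of the pattern these two call sites use: handing the chosen CUDA index to `deepspeed.initialize` through the `args` namespace, which (per this commit) DeepSpeed consults as `device_rank` when selecting the device. The toy model, config dict, and device index are placeholders, and the call assumes a DeepSpeed-capable environment where distributed initialization has already happened (the Lightning plugin handles that); this is a sketch of the pattern, not the plugin code.

import argparse

import torch
import deepspeed

# e.g. the root device Lightning selects for Trainer(gpus=[1])
root_device = torch.device("cuda", 1)

model = torch.nn.Linear(32, 2)  # toy placeholder model
engine, optimizer, _, _ = deepspeed.initialize(
    # Pass the target CUDA index so the engine binds to cuda:1
    # rather than defaulting to the local rank's device.
    args=argparse.Namespace(device_rank=root_device.index),
    config={"train_batch_size": 8},  # minimal placeholder config
    model=model,
    model_parameters=[p for p in model.parameters() if p.requires_grad],
)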

tests/plugins/test_deepspeed_plugin.py

Lines changed: 37 additions & 0 deletions

@@ -970,3 +970,40 @@ def test_different_accumulate_grad_batches_fails(tmpdir):
         MisconfigurationException, match="DeepSpeed currently does not support different `accumulate_grad_batches`"
     ):
         trainer.fit(model)
+
+
+@RunIf(min_gpus=2, deepspeed=True, special=True)
+def test_specific_gpu_device_id(tmpdir):
+    class TestCallback(Callback):
+        def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+            assert model.device.index == 1
+
+        def on_train_batch_start(
+            self,
+            trainer: Trainer,
+            pl_module: LightningModule,
+            batch: Any,
+            batch_idx: int,
+            dataloader_idx: int,
+        ) -> None:
+            assert batch.device.index == 1
+
+        def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+            assert model.device.index == 1
+
+        def on_test_batch_start(
+            self,
+            trainer: Trainer,
+            pl_module: LightningModule,
+            batch: Any,
+            batch_idx: int,
+            dataloader_idx: int,
+        ) -> None:
+            assert batch.device.index == 1
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir, fast_dev_run=True, gpus=[1], plugins="deepspeed", callbacks=TestCallback()
+    )
+    trainer.fit(model)
+    trainer.test(model)
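
For context, a minimal sketch of the user-facing setup this test exercises: pinning training to GPU index 1 while using the DeepSpeed plugin, with the PyTorch Lightning 1.x-era API visible in this diff. The `BoringModel` import path is an assumption based on the test helpers of that era.

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel  # assumed helper location

# Before this fix, DeepSpeed would bind to the local rank's device (cuda:0)
# even when the user explicitly requested GPU index 1.
model = BoringModel()
trainer = Trainer(gpus=[1], plugins="deepspeed", fast_dev_run=True)
trainer.fit(model)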
