
Commit 7239a42

kaushikb11 authored and lexierule committed
Fix ddp accelerator choice for cpu (#8645)
* Fix ddp accelerator choice for cpu
1 parent 7374bc8 commit 7239a42
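
For context, a minimal sketch of the user-facing effect, assuming the 1.4-era `Trainer(accelerator=...)` API (this mirrors the test added below):

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import DDPPlugin

# After this fix, explicitly requesting "ddp" with CPU processes keeps the
# non-spawn DDP plugin instead of being coerced to the spawn variant.
trainer = Trainer(accelerator="ddp", num_processes=2)
assert isinstance(trainer.accelerator, CPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPPlugin)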

File tree

3 files changed: +12 -4 lines changed

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 5 additions & 3 deletions
@@ -607,7 +607,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
         use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
         use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
         use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
-        use_ddp_cpu_spawn = self.use_ddp and self.use_cpu
+        use_ddp_cpu_spawn = use_ddp_spawn and self.use_cpu
         use_tpu_spawn = self.use_tpu and self._distrib_type == DistributedType.TPU_SPAWN
         use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
         use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow()
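
The key change above: `use_ddp_cpu_spawn` is now derived from `use_ddp_spawn` rather than from `use_ddp` alone, so a non-spawn DDP request on CPU no longer falls into the spawn branch. A standalone sketch of that boolean selection, with the surrounding class simplified away (the function and returned plugin names here are illustrative, not the real API):

def pick_plugin(use_ddp: bool, distrib_type: str, use_cpu: bool) -> str:
    use_ddp_spawn = distrib_type == "ddp_spawn"
    # Old: use_ddp_cpu_spawn = use_ddp and use_cpu
    #      -> true even when plain (non-spawn) DDP was requested on CPU.
    use_ddp_cpu_spawn = use_ddp_spawn and use_cpu
    if use_ddp_cpu_spawn:
        return "DDPSpawnPlugin"
    if use_ddp:
        return "DDPPlugin"
    return "SingleDevicePlugin"

# Plain DDP on CPU now resolves to the non-spawn plugin:
assert pick_plugin(use_ddp=True, distrib_type="ddp", use_cpu=True) == "DDPPlugin"
assert pick_plugin(use_ddp=True, distrib_type="ddp_spawn", use_cpu=True) == "DDPSpawnPlugin"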
@@ -738,14 +738,16 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
         if self.distributed_backend is None:
             if self.has_horovodrun():
                 self._set_horovod_backend()
-            elif self.num_gpus == 0 and (self.num_nodes > 1 or self.num_processes > 1):
+            elif self.num_gpus == 0 and self.num_nodes > 1:
                 self._distrib_type = DistributedType.DDP
+            elif self.num_gpus == 0 and self.num_processes > 1:
+                self.distributed_backend = DistributedType.DDP_SPAWN
             elif self.num_gpus > 1 and not _use_cpu:
                 rank_zero_warn(
                     "You requested multiple GPUs but did not specify a backend, e.g."
                     ' `Trainer(accelerator="dp"|"ddp"|"ddp2")`. Setting `accelerator="ddp_spawn"` for you.'
                 )
-                self.distributed_backend = "ddp_spawn"
+                self.distributed_backend = DistributedType.DDP_SPAWN
 
         # special case with DDP on CPUs
         if self.distributed_backend == "ddp_cpu":
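
In `set_distributed_mode`, the CPU default is now split: multiple nodes default to full DDP, while a single node with several processes defaults to DDP spawn. A simplified sketch of just that branch (the helper name is hypothetical; the real logic lives inside the method above):

from typing import Optional

def default_cpu_backend(num_gpus: int, num_nodes: int, num_processes: int) -> Optional[str]:
    # Multi-node CPU runs default to full DDP.
    if num_gpus == 0 and num_nodes > 1:
        return "ddp"
    # Single-node, multi-process CPU runs now default to DDP spawn.
    if num_gpus == 0 and num_processes > 1:
        return "ddp_spawn"
    return None

assert default_cpu_backend(num_gpus=0, num_nodes=2, num_processes=1) == "ddp"
assert default_cpu_backend(num_gpus=0, num_nodes=1, num_processes=4) == "ddp_spawn"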

tests/accelerators/test_accelerator_connector.py

Lines changed: 6 additions & 0 deletions
@@ -623,3 +623,9 @@ def test_unsupported_distrib_types_on_cpu(training_type):
     trainer = Trainer(accelerator=training_type, num_processes=2)
 
     assert trainer._distrib_type == DistributedType.DDP
+
+
+def test_accelerator_ddp_for_cpu(tmpdir):
+    trainer = Trainer(accelerator="ddp", num_processes=2)
+    assert isinstance(trainer.accelerator, CPUAccelerator)
+    assert isinstance(trainer.training_type_plugin, DDPPlugin)

tests/trainer/test_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -1737,7 +1737,7 @@ def on_predict_start(self) -> None:
 
 
 @pytest.mark.parametrize(
-    "accelerator,num_processes", [(None, 1), pytest.param("ddp", 2, marks=RunIf(skip_windows=True))]
+    "accelerator,num_processes", [(None, 1), pytest.param("ddp_cpu", 2, marks=RunIf(skip_windows=True))]
 )
 def test_model_in_correct_mode_during_stages(tmpdir, accelerator, num_processes):
     model = TrainerStagesModel()
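
The parametrization swap from "ddp" to "ddp_cpu" presumably follows from the fix itself: with the new selection logic, `accelerator="ddp"` plus `num_processes=2` resolves to the non-spawn DDP plugin, so this CPU-only stage test keeps its original spawn-based behaviour by requesting "ddp_cpu" explicitly.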
