Skip to content

Commit cee7043

Browse files
kaushikb11 authored and lexierule committed
Fix distributed types support for CPUs (#8667)
1 parent c4500f7 commit cee7043

File tree

3 files changed

+24
-5
lines changed

3 files changed

+24
-5
lines changed

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -779,12 +779,13 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
779779
_gpu_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
780780
# DP and DDP2 cannot run without GPU
781781
if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _use_cpu:
782-
rank_zero_warn(
783-
"You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`."
784-
)
785-
# todo: in some cases it yield in comparison None and int
782+
786783
if (self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1):
787-
self._distrib_type = DistributedType.DDP
784+
if self._distrib_type in (DistributedType.DP, DistributedType.DDP2):
785+
rank_zero_warn(
786+
f"{self._distrib_type} is not supported on CPUs, hence setting the distributed type to `ddp`."
787+
)
788+
self._distrib_type = DistributedType.DDP
788789
else:
789790
rank_zero_warn("You are running on single node with no parallelization, so distributed has no effect.")
790791
self._distrib_type = None

tests/accelerators/test_accelerator_connector.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
SLURMEnvironment,
4343
TorchElasticEnvironment,
4444
)
45+
from pytorch_lightning.utilities import DistributedType
4546
from pytorch_lightning.utilities.exceptions import MisconfigurationException
4647
from tests.helpers.boring_model import BoringModel
4748
from tests.helpers.runif import RunIf
@@ -613,3 +614,12 @@ def test_devices_with_cpu_only_supports_integer():
613614

614615
with pytest.raises(MisconfigurationException, match="The flag `devices` only supports integer"):
615616
Trainer(accelerator="cpu", devices="1,3")
617+
618+
619+
@pytest.mark.parametrize("training_type", ["ddp2", "dp"])
620+
def test_unsupported_distrib_types_on_cpu(training_type):
621+
622+
with pytest.warns(UserWarning, match="is not supported on CPUs, hence setting the distributed type to `ddp`."):
623+
trainer = Trainer(accelerator=training_type, num_processes=2)
624+
625+
assert trainer._distrib_type == DistributedType.DDP

tests/trainer/test_trainer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,6 +1141,14 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
11411141
dict(accelerator="ddp2", gpus=2),
11421142
dict(_distrib_type=DistributedType.DDP2, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1),
11431143
),
1144+
(
1145+
dict(accelerator="ddp2", num_processes=2, gpus=None),
1146+
dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2),
1147+
),
1148+
(
1149+
dict(accelerator="dp", num_processes=2, gpus=None),
1150+
dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2),
1151+
),
11441152
],
11451153
)
11461154
def test_trainer_config(trainer_kwargs, expected, monkeypatch):

0 commit comments

Comments (0)