Skip to content

Commit 927167e

Browse files
committed
add testing
1 parent c253280 commit 927167e

File tree

2 files changed

+68
-2
lines changed

2 files changed

+68
-2
lines changed

tests/tests_fabric/strategies/test_ddp.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,5 +169,44 @@ def test_set_timeout(init_process_group_mock):
169169
global_rank = strategy.cluster_environment.global_rank()
170170
world_size = strategy.cluster_environment.world_size()
171171
init_process_group_mock.assert_called_with(
172-
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
172+
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, device_id=None
173+
)
174+
175+
176+
@mock.patch("torch.distributed.init_process_group")
177+
def test_device_id_passed_for_cuda_devices(init_process_group_mock):
178+
"""Test that device_id is passed to init_process_group for CUDA devices but not for CPU."""
179+
# Test with CPU device - device_id should be None
180+
cpu_strategy = DDPStrategy(parallel_devices=[torch.device("cpu")])
181+
cpu_strategy.cluster_environment = LightningEnvironment()
182+
cpu_strategy.accelerator = Mock()
183+
cpu_strategy.setup_environment()
184+
185+
process_group_backend = cpu_strategy._get_process_group_backend()
186+
global_rank = cpu_strategy.cluster_environment.global_rank()
187+
world_size = cpu_strategy.cluster_environment.world_size()
188+
189+
init_process_group_mock.assert_called_with(
190+
process_group_backend, rank=global_rank, world_size=world_size, timeout=cpu_strategy._timeout, device_id=None
191+
)
192+
193+
init_process_group_mock.reset_mock()
194+
195+
# Test with CUDA device - device_id should be the device
196+
cuda_device = torch.device("cuda", 0)
197+
cuda_strategy = DDPStrategy(parallel_devices=[cuda_device])
198+
cuda_strategy.cluster_environment = LightningEnvironment()
199+
cuda_strategy.accelerator = Mock()
200+
cuda_strategy.setup_environment()
201+
202+
process_group_backend = cuda_strategy._get_process_group_backend()
203+
global_rank = cuda_strategy.cluster_environment.global_rank()
204+
world_size = cuda_strategy.cluster_environment.world_size()
205+
206+
init_process_group_mock.assert_called_with(
207+
process_group_backend,
208+
rank=global_rank,
209+
world_size=world_size,
210+
timeout=cuda_strategy._timeout,
211+
device_id=cuda_device,
173212
)

tests/tests_pytorch/strategies/test_ddp.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,34 @@ def test_set_timeout(mock_init_process_group):
133133
global_rank = trainer.strategy.cluster_environment.global_rank()
134134
world_size = trainer.strategy.cluster_environment.world_size()
135135
mock_init_process_group.assert_called_with(
136-
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
136+
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, device_id=None
137+
)
138+
139+
140+
@mock.patch("torch.distributed.init_process_group")
141+
def test_device_id_passed_for_cuda_devices_pytorch(mock_init_process_group):
142+
"""Test that device_id is passed to init_process_group for CUDA devices but not for CPU."""
143+
# Test with CPU device - device_id should be None
144+
model = BoringModel()
145+
ddp_strategy = DDPStrategy()
146+
trainer = Trainer(
147+
max_epochs=1,
148+
accelerator="cpu",
149+
strategy=ddp_strategy,
150+
)
151+
trainer.strategy.connect(model)
152+
trainer.lightning_module.trainer = trainer
153+
trainer.strategy.setup_environment()
154+
155+
process_group_backend = trainer.strategy._get_process_group_backend()
156+
global_rank = trainer.strategy.cluster_environment.global_rank()
157+
world_size = trainer.strategy.cluster_environment.world_size()
158+
mock_init_process_group.assert_called_with(
159+
process_group_backend,
160+
rank=global_rank,
161+
world_size=world_size,
162+
timeout=trainer.strategy._timeout,
163+
device_id=None,
137164
)
138165

139166

0 commit comments

Comments
 (0)