Skip to content

Commit 38cc242

Browse files
Adds regression test to cover all combinations of optimization/static_graph.
1 parent e2a02cd commit 38cc242

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

tests/tests_pytorch/strategies/test_ddp_integration.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,3 +448,50 @@ def creates_processes_externally(self):
448448
RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
449449
):
450450
trainer.fit(model)
451+
452+
453+
@RunIf(min_cuda_gpus=2, standalone=True)
@pytest.mark.parametrize("automatic_optimization", [True, False])
@pytest.mark.parametrize("static_graph", [True, False])
def test_ddp_gradients_synced(tmp_path, automatic_optimization, static_graph):
    """Run a single DDP step on two GPUs and verify both ranks end up with identical
    gradients, for every combination of automatic/manual optimization and ``static_graph``."""

    class GradSyncModel(BoringModel):
        def __init__(self):
            super().__init__()
            # Parametrized from the enclosing test: toggles Lightning's optimization mode.
            self.automatic_optimization = automatic_optimization

        def training_step(self, batch, batch_idx):
            # Automatic optimization: defer entirely to the parent implementation.
            if self.automatic_optimization:
                return super().training_step(batch, batch_idx)

            # Manual optimization: drive the zero_grad/backward/step cycle ourselves.
            optimizer = self.optimizers()
            optimizer.zero_grad()
            step_output = super().training_step(batch, batch_idx)
            self.manual_backward(step_output["loss"])
            optimizer.step()
            return step_output

        def on_train_batch_end(self, *args, **kwargs):
            # Log this rank's gradient sum twice, reduced with min and max across
            # ranks. If DDP synchronized gradients, both reductions must agree.
            bias_grad_sum = self.layer.bias.grad.detach().sum()
            self.log("grad_sum_min", bias_grad_sum, sync_dist=True, reduce_fx="min")
            self.log("grad_sum_max", bias_grad_sum, sync_dist=True, reduce_fx="max")

    trainer = Trainer(
        default_root_dir=tmp_path,
        accelerator="gpu",
        devices=2,
        strategy=DDPStrategy(static_graph=static_graph),
        max_steps=1,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(GradSyncModel(), datamodule=BoringDataModule())

    # Identical min and max reductions imply every rank observed the same gradients.
    lowest = trainer.callback_metrics["grad_sum_min"]
    highest = trainer.callback_metrics["grad_sum_max"]
    assert torch.allclose(lowest, highest)

0 commit comments

Comments
 (0)