diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index 92206e1accc31..4eca6159ddced 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -103,6 +103,7 @@ def __init__(
         self._process_group_backend: Optional[str] = process_group_backend
         self._timeout: Optional[timedelta] = timeout
         self._start_method = start_method
+        self._pl_static_graph_delay_done = False
 
     @property
     def is_distributed(self) -> bool:  # pragma: no-cover
@@ -319,6 +320,27 @@ def pre_backward(self, closure_loss: Tensor) -> None:
         if not self.lightning_module.automatic_optimization:
             prepare_for_backward(self.model, closure_loss)
 
+    @override
+    def post_backward(self, closure_loss: Tensor) -> None:
+        # Only for first static-graph iteration with manual optimization
+        model = self.model
+        lm = self.lightning_module
+        if not isinstance(model, DistributedDataParallel):
+            return
+        if lm is None or lm.automatic_optimization:
+            return
+        if not getattr(model, "static_graph", False):
+            return
+        if self._pl_static_graph_delay_done:
+            return
+
+        # Call DDP's own first-iter static-graph flush.
+        # This is what actually launches the bucket all-reduces.
+        reducer = model.reducer
+        reducer._delay_all_reduce()
+
+        self._pl_static_graph_delay_done = True
+
     @override
     def model_to_device(self) -> None:
         log.debug(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
diff --git a/tests/tests_pytorch/strategies/test_ddp_integration.py b/tests/tests_pytorch/strategies/test_ddp_integration.py
index fc3a8cfebbac0..6373985687ad3 100644
--- a/tests/tests_pytorch/strategies/test_ddp_integration.py
+++ b/tests/tests_pytorch/strategies/test_ddp_integration.py
@@ -448,3 +448,50 @@ def creates_processes_externally(self):
         RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
     ):
         trainer.fit(model)
+
+
+@RunIf(min_cuda_gpus=2, standalone=True)
+@pytest.mark.parametrize("automatic_optimization", [True, False])
+@pytest.mark.parametrize("static_graph", [True, False])
+def test_ddp_gradients_synced(tmp_path, automatic_optimization, static_graph):
+    """Ensure gradients are synchronized across ranks for both optimization modes and static_graph settings."""
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.automatic_optimization = automatic_optimization
+
+        def training_step(self, batch, batch_idx):
+            if self.automatic_optimization:
+                return super().training_step(batch, batch_idx)
+
+            # manual optimization path
+            opt = self.optimizers()
+            opt.zero_grad()
+            out = super().training_step(batch, batch_idx)
+            loss = out["loss"]
+            self.manual_backward(loss)
+            opt.step()
+            return out
+
+        def on_train_batch_end(self, *args, **kwargs):
+            # record grad sum for sync check
+            grad_sum = self.layer.bias.grad.detach().sum()
+            self.log("grad_sum_min", grad_sum, sync_dist=True, reduce_fx="min")
+            self.log("grad_sum_max", grad_sum, sync_dist=True, reduce_fx="max")
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        accelerator="gpu",
+        devices=2,
+        strategy=DDPStrategy(static_graph=static_graph),
+        max_steps=1,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+    trainer.fit(TestModel(), datamodule=BoringDataModule())
+
+    # assert all ranks saw identical grads
+    gmin = trainer.callback_metrics["grad_sum_min"]
+    gmax = trainer.callback_metrics["grad_sum_max"]
+    assert torch.allclose(gmin, gmax)