src/lightning/pytorch/strategies/ddp.py (22 additions, 0 deletions)
@@ -103,6 +103,7 @@ def __init__(
        self._process_group_backend: Optional[str] = process_group_backend
        self._timeout: Optional[timedelta] = timeout
        self._start_method = start_method
        self._pl_static_graph_delay_done = False

    @property
    def is_distributed(self) -> bool:  # pragma: no-cover
@@ -319,6 +320,27 @@ def pre_backward(self, closure_loss: Tensor) -> None:
        if not self.lightning_module.automatic_optimization:
            prepare_for_backward(self.model, closure_loss)

    @override
    def post_backward(self, closure_loss: Tensor) -> None:
        # Only for the first static-graph iteration with manual optimization
        model = self.model
        lm = self.lightning_module
        if not isinstance(model, DistributedDataParallel):
            return
        if lm is None or lm.automatic_optimization:
            return
        if not getattr(model, "static_graph", False):
            return
        if self._pl_static_graph_delay_done:
            return

        # Call DDP's own first-iteration static-graph flush.
        # This is what actually launches the bucket all-reduces.
        reducer = model.reducer
        reducer._delay_all_reduce()

        self._pl_static_graph_delay_done = True

    @override
    def model_to_device(self) -> None:
        log.debug(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
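With `static_graph=True`, DDP delays the first iteration's bucket all-reduces, and the new `post_backward` hook calls the reducer's first-iteration flush exactly once so those all-reduces are actually launched when the gradients come from `manual_backward`. The user-facing scenario this targets is manual optimization combined with a static-graph DDP strategy; a minimal sketch of that scenario (assuming the `BoringModel` demo helper, a 2-GPU machine, and the illustrative name `ManualDDPModel`):

```python
import lightning.pytorch as pl
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.strategies import DDPStrategy


class ManualDDPModel(BoringModel):
    def __init__(self):
        super().__init__()
        # Manual optimization: the user drives zero_grad/backward/step themselves.
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = super().training_step(batch, batch_idx)["loss"]
        # manual_backward() routes through the strategy's backward(), which wraps
        # loss.backward() with the pre_backward/post_backward hooks shown above.
        self.manual_backward(loss)
        opt.step()
        return {"loss": loss}


if __name__ == "__main__":
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=2,
        strategy=DDPStrategy(static_graph=True),
        max_steps=1,
    )
    trainer.fit(ManualDDPModel())
```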
tests/tests_pytorch/strategies/test_ddp_integration.py (47 additions, 0 deletions)
@@ -448,3 +448,50 @@ def creates_processes_externally(self):
        RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
    ):
        trainer.fit(model)


@RunIf(min_cuda_gpus=2, standalone=True)
@pytest.mark.parametrize("automatic_optimization", [True, False])
@pytest.mark.parametrize("static_graph", [True, False])
def test_ddp_gradients_synced(tmp_path, automatic_optimization, static_graph):
    """Ensure gradients are synchronized across ranks for both optimization modes and static_graph settings."""

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.automatic_optimization = automatic_optimization

        def training_step(self, batch, batch_idx):
            if self.automatic_optimization:
                return super().training_step(batch, batch_idx)

            # manual optimization path
            opt = self.optimizers()
            opt.zero_grad()
            out = super().training_step(batch, batch_idx)
            loss = out["loss"]
            self.manual_backward(loss)
            opt.step()
            return out

        def on_train_batch_end(self, *args, **kwargs):
            # record the per-rank gradient sum for the sync check
            grad_sum = self.layer.bias.grad.detach().sum()
            self.log("grad_sum_min", grad_sum, sync_dist=True, reduce_fx="min")
            self.log("grad_sum_max", grad_sum, sync_dist=True, reduce_fx="max")

    trainer = Trainer(
        default_root_dir=tmp_path,
        accelerator="gpu",
        devices=2,
        strategy=DDPStrategy(static_graph=static_graph),
        max_steps=1,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(TestModel(), datamodule=BoringDataModule())

    # assert all ranks saw identical gradients
    gmin = trainer.callback_metrics["grad_sum_min"]
    gmax = trainer.callback_metrics["grad_sum_max"]
    assert torch.allclose(gmin, gmax)
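The logged min/max pair is what makes the cross-rank comparison possible: if every rank holds the same averaged gradient, the minimum and maximum of the per-rank gradient sums coincide. Outside Lightning's logging, the same check can be written with raw `torch.distributed` collectives; a small sketch, assuming the default process group is already initialized (the helper name `grads_synced` is illustrative):

```python
import torch
import torch.distributed as dist


def grads_synced(param: torch.nn.Parameter) -> bool:
    # Reduce this rank's gradient sum with MIN and MAX across all ranks;
    # identical results mean every rank holds the same gradient values.
    local = param.grad.detach().sum()
    gmin, gmax = local.clone(), local.clone()
    dist.all_reduce(gmin, op=dist.ReduceOp.MIN)
    dist.all_reduce(gmax, op=dist.ReduceOp.MAX)
    return bool(torch.allclose(gmin, gmax))
```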