Commit e2a02cd

fix: synchronize gradients in manual optimization with DDPStrategy(static_graph=True). Ensure gradients are reduced correctly when using manual optimization and DDP with static_graph enabled.
1 parent 3726e54 commit e2a02cd
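
For context, the combination this commit targets can be reproduced with a setup along these lines. This is a minimal sketch, not code from the repository: `ToyModel`, its layer sizes, the optimizer, and the two-device `Trainer` settings are illustrative; `DDPStrategy(static_graph=True)` plus manual optimization are the ingredients named in the commit message.

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    from lightning.pytorch import LightningModule, Trainer
    from lightning.pytorch.strategies import DDPStrategy

    class ToyModel(LightningModule):
        def __init__(self):
            super().__init__()
            # Manual optimization: Lightning no longer drives backward/step,
            # so the strategy's pre/post-backward hooks must handle the
            # gradient-synchronization bookkeeping themselves.
            self.automatic_optimization = False
            self.layer = nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            (x,) = batch
            opt = self.optimizers()
            loss = self.layer(x).sum()
            self.manual_backward(loss)  # routes through the strategy hooks
            opt.step()
            opt.zero_grad()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

    if __name__ == "__main__":
        data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
        # static_graph=True with manual optimization is the combination this
        # commit fixes; per the commit message, first-iteration gradients
        # could previously stay un-reduced across ranks in this setup.
        trainer = Trainer(strategy=DDPStrategy(static_graph=True), devices=2, max_epochs=1)
        trainer.fit(ToyModel(), data)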

File tree

1 file changed: 21 additions, 0 deletions

  • src/lightning/pytorch/strategies/ddp.py


src/lightning/pytorch/strategies/ddp.py

Lines changed: 21 additions & 0 deletions
@@ -319,6 +319,27 @@ def pre_backward(self, closure_loss: Tensor) -> None:
         if not self.lightning_module.automatic_optimization:
             prepare_for_backward(self.model, closure_loss)
 
+    @override
+    def post_backward(self, closure_loss: Tensor) -> None:
+        # Only for first static-graph iteration with manual optimization
+        model = self.model
+        lm = self.lightning_module
+        if not isinstance(model, DistributedDataParallel):
+            return
+        if lm is None or lm.automatic_optimization:
+            return
+        if not getattr(model, "static_graph", False):
+            return
+        if self._pl_static_graph_delay_done:
+            return
+
+        # Call DDP's own first-iter static-graph flush.
+        # This is what actually launches the bucket all-reduces.
+        reducer = model.reducer
+        reducer._delay_all_reduce()
+
+        self._pl_static_graph_delay_done = True
+
     @override
     def model_to_device(self) -> None:
         log.debug(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
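
The new hook leans on a private DDP internal, `Reducer._delay_all_reduce()`, which launches the bucket all-reduces that DDP defers on the first static-graph iteration. Below is a rough plain-PyTorch illustration of that mechanism, assuming a process group is already initialized on each rank and that the deferred flush has not already run (outside Lightning, DDP normally schedules it itself during backward):

    import torch
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel

    # Assumes torch.distributed.init_process_group(...) has already run on
    # each rank; device placement is elided for brevity.
    model = DistributedDataParallel(nn.Linear(8, 1), static_graph=True)
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    # With static_graph=True, the first iteration's bucket all-reduces are
    # deferred; this private call is the same flush the new post_backward
    # hook issues when manual optimization would otherwise leave it unfired.
    model.reducer._delay_all_reduce()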
