@@ -448,3 +448,50 @@ def creates_processes_externally(self):
         RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
     ):
         trainer.fit(model)
+
+
+@RunIf(min_cuda_gpus=2, standalone=True)
+@pytest.mark.parametrize("automatic_optimization", [True, False])
+@pytest.mark.parametrize("static_graph", [True, False])
+def test_ddp_gradients_synced(tmp_path, automatic_optimization, static_graph):
+    """Ensure gradients are synchronized across ranks for both optimization modes and static_graph settings."""
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.automatic_optimization = automatic_optimization
+
+        def training_step(self, batch, batch_idx):
+            if self.automatic_optimization:
+                return super().training_step(batch, batch_idx)
+
+            # manual optimization path
+            opt = self.optimizers()
+            opt.zero_grad()
+            out = super().training_step(batch, batch_idx)
+            loss = out["loss"]
+            self.manual_backward(loss)
+            opt.step()
+            return out
+
+        def on_train_batch_end(self, *args, **kwargs):
+            # record grad sum for sync check
+            grad_sum = self.layer.bias.grad.detach().sum()
+            self.log("grad_sum_min", grad_sum, sync_dist=True, reduce_fx="min")
+            self.log("grad_sum_max", grad_sum, sync_dist=True, reduce_fx="max")
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        accelerator="gpu",
+        devices=2,
+        strategy=DDPStrategy(static_graph=static_graph),
+        max_steps=1,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+    trainer.fit(TestModel(), datamodule=BoringDataModule())
+
+    # assert all ranks saw identical grads
+    gmin = trainer.callback_metrics["grad_sum_min"]
+    gmax = trainer.callback_metrics["grad_sum_max"]
+    assert torch.allclose(gmin, gmax)