@@ -657,9 +657,9 @@ def train_custom(
657657
658658 # forward and backward for batch
659659 for batch_step_no , batch_step in enumerate (batch_steps ):
660- skip_sync = multi_gpu and batch_step_no < len (batch_steps ) - 1
661- gradient_sync = contextlib . nullcontext () if skip_sync else self . ddp_model . no_sync ()
662- with gradient_sync :
660+ enable_gradient_sync = multi_gpu and batch_step_no == len (batch_steps ) - 1
661+ sync_context = self . ddp_model . no_sync () if enable_gradient_sync else contextlib . nullcontext ()
662+ with sync_context :
663663 # forward pass
664664 with torch .autocast (device_type = flair .device .type , enabled = use_amp ):
665665 if multi_gpu :
@@ -690,6 +690,8 @@ def wrapped_forward_loss(*args, **kwargs2):
690690 self .dispatch ("before_training_optimizer_step" , ** batch_kw )
691691
692692 # do the optimizer step
693+ if multi_gpu :
694+ self ._scale_gradients (torch .distributed .get_world_size ()) # DDP averages across processes but we want the sum
693695 scaler .unscale_ (self .optimizer )
694696 if max_grad_norm is not None :
695697 gradient_norm = torch .nn .utils .clip_grad_norm_ (self .model .parameters (), max_grad_norm )
@@ -988,3 +990,8 @@ def _save_model(self, model_file: Union[str, Path], checkpoint: bool = False) ->
988990 self .model .save (model_file , checkpoint )
989991 if torch .distributed .is_initialized ():
990992 torch .distributed .barrier () # Prevent any process from loading a model until writing is complete
993+
def _scale_gradients(self, constant):
    """Multiply every populated parameter gradient of ``self.model`` by ``constant``, in place.

    Args:
        constant: Multiplier applied to each gradient tensor. The caller visible
            in this file passes ``torch.distributed.get_world_size()`` to turn
            the gradient averaged across processes back into a sum.

    Returns:
        None. Gradients are modified in place; parameters whose ``grad`` is
        ``None`` are skipped.
    """
    # Do the in-place update under no_grad rather than through the legacy
    # `.data` attribute: `.data` hides the mutation from autograd's version
    # tracking, while no_grad keeps the operation untracked without doing so.
    with torch.no_grad():
        for param in self.model.parameters():
            if param.grad is not None:
                param.grad.mul_(constant)
0 commit comments