@@ -656,37 +656,43 @@ def train_custom(
656656 batch_steps = self .get_batch_steps (batch , mini_batch_chunk_size = mini_batch_chunk_size )
657657
658658 # forward and backward for batch
659- for batch_step in batch_steps :
660- # forward pass
661- with torch .autocast (device_type = flair .device .type , enabled = use_amp ):
662- if multi_gpu :
663- # We need to __call__ ddp_model() because this triggers hooks that sync gradients.
664- # But that calls forward rather than forward_loss. So we patch forward to redirect
665- # to forward_loss. Then undo the patch in case forward_loss itself calls forward.
666- def wrapped_forward_loss (* args , ** kwargs2 ):
667- self .model .forward = original_forward
668- return self .model .forward_loss (* args , ** kwargs2 )
669-
670- self .model .forward = wrapped_forward_loss
671- loss , datapoint_count = self .ddp_model (batch_step )
672- else :
673- loss , datapoint_count = self .model .forward_loss (batch_step )
674-
675- batch_train_samples += datapoint_count
676- batch_train_loss += loss .item ()
677-
678- self ._backward (scaler .scale (loss ))
679-
680- # identify dynamic embeddings (always deleted) on first sentence
681- if dynamic_embeddings is None :
682- dynamic_embeddings = identify_dynamic_embeddings (batch )
683-
684- # depending on memory mode, embeddings are moved to CPU, GPU or deleted
685- store_embeddings (batch_step , embeddings_storage_mode , dynamic_embeddings )
659+ for batch_step_no , batch_step in enumerate (batch_steps ):
660+ disable_gradient_sync = multi_gpu and batch_step_no < len (batch_steps ) - 1
661+ grad_sync = self .ddp_model .no_sync () if disable_gradient_sync else contextlib .nullcontext ()
662+ with grad_sync :
663+ # forward pass
664+ with torch .autocast (device_type = flair .device .type , enabled = use_amp ):
665+ if multi_gpu :
666+ # We need to __call__ ddp_model() because this triggers hooks that sync gradients.
667+ # But that calls forward rather than forward_loss. So we patch forward to redirect
668+ # to forward_loss. Then undo the patch in case forward_loss itself calls forward.
669+ def wrapped_forward_loss (* args , ** kwargs2 ):
670+ self .model .forward = original_forward
671+ return self .model .forward_loss (* args , ** kwargs2 )
672+
673+ self .model .forward = wrapped_forward_loss
674+ loss , datapoint_count = self .ddp_model (batch_step )
675+ else :
676+ loss , datapoint_count = self .model .forward_loss (batch_step )
677+
678+ batch_train_samples += datapoint_count
679+ batch_train_loss += loss .item ()
680+
681+ self ._backward (scaler .scale (loss ))
682+
683+ # identify dynamic embeddings (always deleted) on first sentence
684+ if dynamic_embeddings is None :
685+ dynamic_embeddings = identify_dynamic_embeddings (batch )
686+
687+ # depending on memory mode, embeddings are moved to CPU, GPU or deleted
688+ store_embeddings (batch_step , embeddings_storage_mode , dynamic_embeddings )
686689
687690 self .dispatch ("before_training_optimizer_step" , ** batch_kw )
688691
689692 # do the optimizer step
693+ if multi_gpu :
694+ # DDP averages across processes but we want the sum
695+ self ._scale_gradients (torch .distributed .get_world_size ())
690696 scaler .unscale_ (self .optimizer )
691697 if max_grad_norm is not None :
692698 gradient_norm = torch .nn .utils .clip_grad_norm_ (self .model .parameters (), max_grad_norm )
@@ -985,3 +991,8 @@ def _save_model(self, model_file: Union[str, Path], checkpoint: bool = False) ->
985991 self .model .save (model_file , checkpoint )
986992 if torch .distributed .is_initialized ():
987993 torch .distributed .barrier () # Prevent any process from loading a model until writing is complete
994+
def _scale_gradients(self, constant: float) -> None:
    """Multiply every existing parameter gradient of ``self.model`` in place by ``constant``.

    The training loop calls this with the distributed world size before the
    optimizer step, to turn the cross-process gradient average into a sum.

    Args:
        constant: Factor applied to each gradient tensor.
    """
    # Edit the gradients in place without recording the operation in autograd.
    # This replaces the legacy ``.data`` access, which bypasses autograd's
    # in-place modification tracking.
    with torch.no_grad():
        for param in self.model.parameters():
            # Parameters that did not take part in the backward pass have no gradient.
            if param.grad is not None:
                param.grad.mul_(constant)
0 commit comments