Skip to content

Commit 0becfed

Browse files
authored
Merge pull request #3583 from ZipRecruiter/jeffp.multi-gpu-fixes
Multigpu: Fix gradient accumulation and learning rate aggregation
2 parents 29feea4 + a2edb9e commit 0becfed

File tree

2 files changed

+38
-29
lines changed

2 files changed

+38
-29
lines changed

flair/trainers/plugins/functional/checkpoints.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ def after_training_epoch(self, epoch, **kw):
3030
)
3131
model_name = "model_epoch_" + str(epoch) + ".pt"
3232
self.model.save(self.base_path / model_name, checkpoint=self.save_optimizer_state)
33-
if torch.distributed.is_initialized():
34-
torch.distributed.barrier() # Prevent any process from loading a model until writing is complete
3533

3634
@property
3735
def attach_to_all_processes(self) -> bool:

flair/trainers/trainer.py

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -656,37 +656,43 @@ def train_custom(
656656
batch_steps = self.get_batch_steps(batch, mini_batch_chunk_size=mini_batch_chunk_size)
657657

658658
# forward and backward for batch
659-
for batch_step in batch_steps:
660-
# forward pass
661-
with torch.autocast(device_type=flair.device.type, enabled=use_amp):
662-
if multi_gpu:
663-
# We need to __call__ ddp_model() because this triggers hooks that sync gradients.
664-
# But that calls forward rather than forward_loss. So we patch forward to redirect
665-
# to forward_loss. Then undo the patch in case forward_loss itself calls forward.
666-
def wrapped_forward_loss(*args, **kwargs2):
667-
self.model.forward = original_forward
668-
return self.model.forward_loss(*args, **kwargs2)
669-
670-
self.model.forward = wrapped_forward_loss
671-
loss, datapoint_count = self.ddp_model(batch_step)
672-
else:
673-
loss, datapoint_count = self.model.forward_loss(batch_step)
674-
675-
batch_train_samples += datapoint_count
676-
batch_train_loss += loss.item()
677-
678-
self._backward(scaler.scale(loss))
679-
680-
# identify dynamic embeddings (always deleted) on first sentence
681-
if dynamic_embeddings is None:
682-
dynamic_embeddings = identify_dynamic_embeddings(batch)
683-
684-
# depending on memory mode, embeddings are moved to CPU, GPU or deleted
685-
store_embeddings(batch_step, embeddings_storage_mode, dynamic_embeddings)
659+
for batch_step_no, batch_step in enumerate(batch_steps):
660+
disable_gradient_sync = multi_gpu and batch_step_no < len(batch_steps) - 1
661+
grad_sync = self.ddp_model.no_sync() if disable_gradient_sync else contextlib.nullcontext()
662+
with grad_sync:
663+
# forward pass
664+
with torch.autocast(device_type=flair.device.type, enabled=use_amp):
665+
if multi_gpu:
666+
# We need to __call__ ddp_model() because this triggers hooks that sync gradients.
667+
# But that calls forward rather than forward_loss. So we patch forward to redirect
668+
# to forward_loss. Then undo the patch in case forward_loss itself calls forward.
669+
def wrapped_forward_loss(*args, **kwargs2):
670+
self.model.forward = original_forward
671+
return self.model.forward_loss(*args, **kwargs2)
672+
673+
self.model.forward = wrapped_forward_loss
674+
loss, datapoint_count = self.ddp_model(batch_step)
675+
else:
676+
loss, datapoint_count = self.model.forward_loss(batch_step)
677+
678+
batch_train_samples += datapoint_count
679+
batch_train_loss += loss.item()
680+
681+
self._backward(scaler.scale(loss))
682+
683+
# identify dynamic embeddings (always deleted) on first sentence
684+
if dynamic_embeddings is None:
685+
dynamic_embeddings = identify_dynamic_embeddings(batch)
686+
687+
# depending on memory mode, embeddings are moved to CPU, GPU or deleted
688+
store_embeddings(batch_step, embeddings_storage_mode, dynamic_embeddings)
686689

687690
self.dispatch("before_training_optimizer_step", **batch_kw)
688691

689692
# do the optimizer step
693+
if multi_gpu:
694+
# DDP averages gradients across processes, but we want their sum, so scale them back up by the world size
695+
self._scale_gradients(torch.distributed.get_world_size())
690696
scaler.unscale_(self.optimizer)
691697
if max_grad_norm is not None:
692698
gradient_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
@@ -985,3 +991,8 @@ def _save_model(self, model_file: Union[str, Path], checkpoint: bool = False) ->
985991
self.model.save(model_file, checkpoint)
986992
if torch.distributed.is_initialized():
987993
torch.distributed.barrier() # Prevent any process from loading a model until writing is complete
994+
995+
def _scale_gradients(self, constant):
996+
for param in self.model.parameters():
997+
if param.grad is not None:
998+
param.grad.data.mul_(constant)

0 commit comments

Comments
 (0)