Skip to content

Commit 56b2a9e

Browse files
committed
fix: Sum gradients instead of averaging on multi-GPU sync
1 parent 00b0d36 commit 56b2a9e

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

flair/trainers/trainer.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -657,9 +657,9 @@ def train_custom(
657657

658658
# forward and backward for batch
659659
for batch_step_no, batch_step in enumerate(batch_steps):
660-
skip_sync = multi_gpu and batch_step_no < len(batch_steps) - 1
661-
gradient_sync = contextlib.nullcontext() if skip_sync else self.ddp_model.no_sync()
662-
with gradient_sync:
660+
enable_gradient_sync = multi_gpu and batch_step_no == len(batch_steps) - 1
661+
sync_context = self.ddp_model.no_sync() if enable_gradient_sync else contextlib.nullcontext()
662+
with sync_context:
663663
# forward pass
664664
with torch.autocast(device_type=flair.device.type, enabled=use_amp):
665665
if multi_gpu:
@@ -690,6 +690,8 @@ def wrapped_forward_loss(*args, **kwargs2):
690690
self.dispatch("before_training_optimizer_step", **batch_kw)
691691

692692
# do the optimizer step
693+
if multi_gpu:
694+
self._scale_gradients(torch.distributed.get_world_size()) # DDP averages across processes but we want the sum
693695
scaler.unscale_(self.optimizer)
694696
if max_grad_norm is not None:
695697
gradient_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
@@ -988,3 +990,8 @@ def _save_model(self, model_file: Union[str, Path], checkpoint: bool = False) ->
988990
self.model.save(model_file, checkpoint)
989991
if torch.distributed.is_initialized():
990992
torch.distributed.barrier() # Prevent any process from loading a model until writing is complete
993+
994+
def _scale_gradients(self, constant: float) -> None:
    """Multiply every existing parameter gradient of ``self.model`` by ``constant``.

    Used after a DDP backward pass: DDP averages gradients across processes,
    but the trainer wants their sum, so gradients are rescaled by the world
    size before the optimizer step.

    Args:
        constant: Factor each gradient is multiplied by in place.
    """
    # no_grad() + in-place mul_ replaces the deprecated `.data` escape hatch;
    # parameters whose grad is None (e.g. frozen layers) are skipped.
    with torch.no_grad():
        for param in self.model.parameters():
            if param.grad is not None:
                param.grad.mul_(constant)

0 commit comments

Comments
 (0)