Skip to content

Commit a2edb9e

Browse files
committed
Fix deadlock in checkpoint plugin
1 parent 56b2a9e commit a2edb9e

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

flair/trainers/plugins/functional/checkpoints.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ def after_training_epoch(self, epoch, **kw):
3030
)
3131
model_name = "model_epoch_" + str(epoch) + ".pt"
3232
self.model.save(self.base_path / model_name, checkpoint=self.save_optimizer_state)
33-
if torch.distributed.is_initialized():
34-
torch.distributed.barrier() # Prevent any process from loading a model until writing is complete
3533

3634
@property
3735
def attach_to_all_processes(self) -> bool:

flair/trainers/trainer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -657,9 +657,9 @@ def train_custom(
657657

658658
# forward and backward for batch
659659
for batch_step_no, batch_step in enumerate(batch_steps):
660-
enable_gradient_sync = multi_gpu and batch_step_no == len(batch_steps) - 1
661-
sync_context = self.ddp_model.no_sync() if enable_gradient_sync else contextlib.nullcontext()
662-
with sync_context:
660+
disable_gradient_sync = multi_gpu and batch_step_no < len(batch_steps) - 1
661+
grad_sync = self.ddp_model.no_sync() if disable_gradient_sync else contextlib.nullcontext()
662+
with grad_sync:
663663
# forward pass
664664
with torch.autocast(device_type=flair.device.type, enabled=use_amp):
665665
if multi_gpu:
@@ -691,7 +691,8 @@ def wrapped_forward_loss(*args, **kwargs2):
691691

692692
# do the optimizer step
693693
if multi_gpu:
694-
self._scale_gradients(torch.distributed.get_world_size()) # DDP averages across processes but we want the sum
694+
# DDP averages across processes but we want the sum
695+
self._scale_gradients(torch.distributed.get_world_size())
695696
scaler.unscale_(self.optimizer)
696697
if max_grad_norm is not None:
697698
gradient_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)

0 commit comments

Comments (0)