
Commit 5cc81aa

save model in global rank 0 in multinode (#1357)
* save model in global rank 0 in multinode
* set_epoch only when training
1 parent abfa4f9 commit 5cc81aa

File tree

2 files changed: +3 -2 lines changed
  distributed/ddp-tutorial-series/multinode.py
  distributed/minGPT-ddp/mingpt/trainer.py

distributed/ddp-tutorial-series/multinode.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def _save_snapshot(self, epoch):
     def train(self, max_epochs: int):
         for epoch in range(self.epochs_run, max_epochs):
             self._run_epoch(epoch)
-            if self.local_rank == 0 and epoch % self.save_every == 0:
+            if self.global_rank == 0 and epoch % self.save_every == 0:
                 self._save_snapshot(epoch)
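In a multi-node torchrun job, `local_rank == 0` is true for one process on every node, so the old check wrote one snapshot per node; `global_rank == 0` is true for exactly one process in the whole job. A minimal sketch of the distinction, assuming the standard torchrun environment variables (RANK, LOCAL_RANK) and a placeholder model rather than the tutorial's full Trainer class:

```python
import os
import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])  # 0..nproc_per_node-1, unique only within a node
    global_rank = int(os.environ["RANK"])       # 0..world_size-1, unique across the whole job
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(10, 10).to(local_rank)  # placeholder model for illustration
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

    # local_rank == 0 holds once per node (N writes in an N-node job);
    # global_rank == 0 holds exactly once, so only one snapshot is written.
    if global_rank == 0:
        torch.save(model.module.state_dict(), "snapshot.pt")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

Launched with something like `torchrun --nnodes=2 --nproc_per_node=4 script.py`, only the single process whose RANK is 0 writes snapshot.pt, rather than one process per node.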

distributed/minGPT-ddp/mingpt/trainer.py

Lines changed: 2 additions & 1 deletion
@@ -111,7 +111,8 @@ def _run_batch(self, source, targets, train: bool = True) -> float:
         return loss.item()

     def _run_epoch(self, epoch: int, dataloader: DataLoader, train: bool = True):
-        dataloader.sampler.set_epoch(epoch)
+        if train:
+            dataloader.sampler.set_epoch(epoch)
         for iter, (source, targets) in enumerate(dataloader):
             step_type = "Train" if train else "Eval"
             source = source.to(self.local_rank)
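`set_epoch` is a DistributedSampler method, which the training dataloader uses so each epoch reshuffles its shard of the data; an eval dataloader typically uses a plain sampler with no `set_epoch` method, so calling it unconditionally can raise AttributeError. A minimal, illustrative sketch of the guarded call (the loader setup and loop body are stand-ins, not the repo's code beyond what the diff shows):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def run_epoch(epoch: int, dataloader: DataLoader, train: bool = True):
    if train:
        # Reseed DistributedSampler so each epoch sees a different shuffle;
        # eval samplers generally do not define set_epoch.
        dataloader.sampler.set_epoch(epoch)
    for source, targets in dataloader:
        pass  # forward/backward (train) or forward only (eval) would go here

dataset = TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))
# num_replicas/rank given explicitly so the sketch runs without init_process_group
train_loader = DataLoader(dataset, batch_size=8,
                          sampler=DistributedSampler(dataset, num_replicas=1, rank=0))
eval_loader = DataLoader(dataset, batch_size=8)  # default SequentialSampler: no set_epoch

for epoch in range(2):
    run_epoch(epoch, train_loader, train=True)
    run_epoch(epoch, eval_loader, train=False)  # guard skips set_epoch here
```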
