Commit 941144a

Fix broken data.pth issue.
Previously, if the program crashed while saving a checkpoint, it left a broken `data.pth` behind. The updated code writes a complete checkpoint file first and only then renames it to `data.pth` (or, on checkpoint-interval steps, symlinks `data.pth` to it).
1 parent 402ba75 commit 941144a

1 file changed: +6 -5 lines changed

qmb/common.py

Lines changed: 6 additions & 5 deletions
@@ -77,13 +77,14 @@ def save(self, data: typing.Any, step: int) -> None:
         """
         Save data to checkpoint.
         """
+        data_pth = self.folder() / "data.pth"
+        local_data_pth = self.folder() / f"data.{step}.pth"
+        torch.save(data, local_data_pth)
+        data_pth.unlink(missing_ok=True)
         if step % self.checkpoint_interval == 0:
-            (self.folder() / "data.pth").unlink(missing_ok=True)
-            torch.save(data, self.folder() / f"data.{step}.pth")
-            (self.folder() / "data.pth").symlink_to(f"data.{step}.pth")
+            data_pth.symlink_to(f"data.{step}.pth")
         else:
-            (self.folder() / "data.pth").unlink(missing_ok=True)
-            torch.save(data, self.folder() / "data.pth")
+            local_data_pth.rename(data_pth)
         if self.max_relative_step is not None:
             self.max_absolute_step = step + self.max_relative_step - 1
             self.max_relative_step = None
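
The change above follows the general "write the full file first, then move it into place" pattern. Note that the committed code still unlinks `data.pth` before the rename, so a crash in between could leave no `data.pth` at all, although the complete `data.{step}.pth` would remain on disk. Below is a minimal sketch of a variant that skips the unlink and relies on `os.replace` to overwrite the target in one step. It is not part of this commit: `atomic_save` and `load` are hypothetical helper names, and `pickle` stands in for `torch.save` / `torch.load` so the example runs without PyTorch.

```python
import os
import pickle
import tempfile
from pathlib import Path
from typing import Any


def atomic_save(data: Any, target: Path) -> None:
    """Write data to a temporary file, then move it onto target in one step.

    Hypothetical helper for illustration; pickle stands in for torch.save.
    """
    target.parent.mkdir(parents=True, exist_ok=True)
    # Create the temporary file next to the target so that os.replace
    # never has to cross a filesystem boundary.
    fd, tmp_name = tempfile.mkstemp(
        dir=target.parent, prefix=target.name + ".", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "wb") as handle:
            pickle.dump(data, handle)  # stand-in for torch.save(data, handle)
            handle.flush()
            os.fsync(handle.fileno())  # push the bytes to disk before the move
        # os.replace overwrites an existing target, so no unlink is needed.
        os.replace(tmp_name, target)
    except BaseException:
        # On any failure, remove the partial temporary file and re-raise.
        Path(tmp_name).unlink(missing_ok=True)
        raise


def load(target: Path) -> Any:
    """Read back a file written by atomic_save."""
    with target.open("rb") as handle:
        return pickle.load(handle)
```

With this approach a reader of `data.pth` sees either the previous complete checkpoint or the new one. `os.replace` is documented as atomic on POSIX when source and destination are on the same filesystem; behavior on other platforms and network filesystems should be checked separately.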
