Skip to content

Commit c02059d

Browse files
committed
fix ckpt tests
1 parent 232c707 commit c02059d

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

alphafold3_pytorch/trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def __init__(
112112
accelerator = 'auto',
113113
checkpoint_every: int = 1000,
114114
checkpoint_folder: str = './checkpoints',
115+
overwrite_checkpoints: bool = False,
115116
fabric_kwargs: dict = dict(),
116117
ema_kwargs: dict = dict()
117118
):
@@ -199,6 +200,7 @@ def __init__(
199200
# checkpointing logic
200201

201202
self.checkpoint_every = checkpoint_every
203+
self.overwrite_checkpoints = overwrite_checkpoints
202204
self.checkpoint_folder = Path(checkpoint_folder)
203205

204206
self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
@@ -367,7 +369,9 @@ def __call__(
367369
self.wait()
368370

369371
if self.is_main and divisible_by(self.steps, self.checkpoint_every):
370-
self.save(self.checkpoint_folder / f'af3.ckpt.{self.steps}.pt')
372+
checkpoint_path = self.checkpoint_folder / f'af3.ckpt.{self.steps}.pt'
373+
374+
self.save(checkpoint_path, overwrite = self.overwrite_checkpoints)
371375

372376
self.wait()
373377

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.61"
3+
version = "0.0.62"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_trainer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22
os.environ['TYPECHECK'] = 'True'
33

4+
from pathlib import Path
5+
46
import pytest
57
import torch
68
from torch.utils.data import Dataset, DataLoader
@@ -133,7 +135,8 @@ def test_trainer():
133135
batch_size = 1,
134136
valid_every = 1,
135137
grad_accum_every = 2,
136-
checkpoint_every = 1
138+
checkpoint_every = 1,
139+
overwrite_checkpoints = True
137140
)
138141

139142
trainer()

0 commit comments

Comments
 (0)