Skip to content

Commit 232c707

Browse files
committed
checkpointing
1 parent 99fd519 commit 232c707

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

alphafold3_pytorch/trainer.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ def __init__(
110110
default_lambda_lr = default_lambda_lr_fn,
111111
fabric: Fabric | None = None,
112112
accelerator = 'auto',
113+
checkpoint_every: int = 1000,
114+
checkpoint_folder: str = './checkpoints',
113115
fabric_kwargs: dict = dict(),
114116
ema_kwargs: dict = dict()
115117
):
@@ -194,6 +196,14 @@ def __init__(
194196

195197
self.steps = 0
196198

199+
# checkpointing logic
200+
201+
self.checkpoint_every = checkpoint_every
202+
self.checkpoint_folder = Path(checkpoint_folder)
203+
204+
self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
205+
assert self.checkpoint_folder.is_dir()
206+
197207
@property
198208
def is_main(self):
199209
return self.fabric.global_rank == 0
@@ -356,6 +366,11 @@ def __call__(
356366

357367
self.wait()
358368

369+
if self.is_main and divisible_by(self.steps, self.checkpoint_every):
370+
self.save(self.checkpoint_folder / f'af3.ckpt.{self.steps}.pt')
371+
372+
self.wait()
373+
359374
# maybe test
360375

361376
if self.is_main and self.needs_test:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.60"
3+
version = "0.0.61"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_trainer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,14 @@ def test_trainer():
132132
num_train_steps = 2,
133133
batch_size = 1,
134134
valid_every = 1,
135-
grad_accum_every = 2
135+
grad_accum_every = 2,
136+
checkpoint_every = 1
136137
)
137138

138139
trainer()
139140

141+
assert Path('./checkpoints/af3.ckpt.1.pt').exists()
142+
140143
# saving and loading from trainer
141144

142145
trainer.save('./some/nested/folder2/training', overwrite = True)

0 commit comments

Comments
 (0)