Skip to content

Commit 7c664d3

Browse files
committed
handle a directory being passed into Trainer.load
1 parent 3231cab commit 7c664d3

File tree

3 files changed

+34
-5
lines changed

3 files changed

+34
-5
lines changed

alphafold3_pytorch/trainer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,10 @@ def __init__(
316316
self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
317317
assert self.checkpoint_folder.is_dir()
318318

319+
# save the path for the last loaded model, if any
320+
321+
self.model_loaded_from_path = None
322+
319323
@property
320324
def is_main(self):
321325
return self.fabric.global_rank == 0
@@ -345,6 +349,23 @@ def load(self, path: str | Path, strict = True):
345349

346350
assert path.exists()
347351

352+
# if the path is a directory, then automatically load latest checkpoint
353+
354+
if path.is_dir():
355+
model_paths = [*path.glob('**/*.pt')]
356+
357+
assert len(model_paths) > 0, f'no files found in directory {path}'
358+
359+
model_paths = sorted(model_paths, key = lambda p: int(str(p).split('.')[-2]))
360+
361+
path = model_paths[-1]
362+
363+
# for eventually saving entire training history in filename
364+
365+
self.model_loaded_from_path = path
366+
367+
# load model from path
368+
348369
self.model.load(path)
349370

350371
package = torch.load(str(path))

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.17"
3+
version = "0.1.18"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_trainer.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_trainer():
9494
depth = 1
9595
),
9696
pairformer_stack = dict(
97-
depth = 2
97+
depth = 1
9898
),
9999
diffusion_module_kwargs = dict(
100100
atom_encoder_depth = 1,
@@ -147,13 +147,21 @@ def test_trainer():
147147

148148
trainer()
149149

150+
# assert checkpoints created
151+
150152
assert Path('./checkpoints/af3.ckpt.1.pt').exists()
151153

154+
# assert can load latest checkpoint by loading from a directory
155+
156+
trainer.load('./checkpoints')
157+
158+
assert str(trainer.model_loaded_from_path) == str(Path('./checkpoints/af3.ckpt.2.pt'))
159+
152160
# saving and loading from trainer
153161

154-
trainer.save('./some/nested/folder2/training', overwrite = True)
155-
trainer.load('./some/nested/folder2/training')
162+
trainer.save('./some/nested/folder2/training.pt', overwrite = True)
163+
trainer.load('./some/nested/folder2/training.pt')
156164

157165
# also allow for loading Alphafold3 directly from training ckpt
158166

159-
alphafold3 = Alphafold3.init_and_load('./some/nested/folder2/training')
167+
alphafold3 = Alphafold3.init_and_load('./some/nested/folder2/training.pt')

0 commit comments

Comments
 (0)