Skip to content

Commit 3d9055d

Browse files
committed
ready for fine-tuning
1 parent 419b024 commit 3d9055d

File tree

3 files changed

+12
-2
lines changed

3 files changed

+12
-2
lines changed

alphafold3_pytorch/trainer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,8 @@ def load(
354354
self,
355355
path: str | Path,
356356
strict = True,
357-
prefix = None
357+
prefix = None,
358+
only_model = False
358359
):
359360
if isinstance(path, str):
360361
path = Path(path)
@@ -382,6 +383,11 @@ def load(
382383

383384
self.model.load(path)
384385

386+
if only_model:
387+
return
388+
389+
# load optimizer and scheduler states
390+
385391
package = torch.load(str(path))
386392

387393
if 'optimizer' in package:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.25"
3+
version = "0.1.26"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,10 @@ def test_trainer():
164164
trainer.save('./some/nested/folder2/training.pt', overwrite = True)
165165
trainer.load('./some/nested/folder2/training.pt')
166166

167+
# allow for only loading model, needed for fine-tuning logic
168+
169+
trainer.load('./some/nested/folder2/training.pt', only_model = True)
170+
167171
# also allow for loading Alphafold3 directly from training ckpt
168172

169173
alphafold3 = Alphafold3.init_and_load('./some/nested/folder2/training.pt')

0 commit comments

Comments
 (0)