Skip to content

Commit 585aca0

Browse files
committed
complete saving and loading trainer states
1 parent d80a997 commit 585aca0

File tree

3 files changed

+53
-1
lines changed

3 files changed

+53
-1
lines changed

alphafold3_pytorch/trainer.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from pathlib import Path
4+
35
from alphafold3_pytorch.alphafold3 import Alphafold3
46

57
from typing import TypedDict
@@ -187,6 +189,45 @@ def __init__(
187189
def is_main(self):
    """Return True when the fabric reports global rank 0 for this process."""
    rank = self.fabric.global_rank
    return rank == 0
189191

192+
# saving and loading
193+
194+
def save(self, path: str | Path, overwrite = False):
    """Write a training checkpoint to `path`.

    The checkpoint is a dict with the keys `model`, `optimizer`,
    `scheduler` and `steps`, serialized with `torch.save`.

    Args:
        path: Destination file. Missing parent directories are created.
        overwrite: Allow replacing something that already exists at `path`.

    Raises:
        AssertionError: If `path` is a directory, or if it already exists
            and `overwrite` is false.
    """
    path = Path(path)

    assert not path.is_dir(), f'{str(path)} is a directory, expected a file path'
    assert overwrite or not path.exists(), f'{str(path)} already exists, pass overwrite = True to replace it'

    path.parent.mkdir(exist_ok = True, parents = True)

    package = dict(
        model = self.model.state_dict_with_init_args,
        optimizer = self.optimizer.state_dict(),
        scheduler = self.scheduler.state_dict(),
        steps = self.steps
    )

    # torch.save takes the object first and the destination second
    torch.save(package, str(path))
210+
211+
def load(self, path: str | Path, strict = True):
    """Restore trainer state from a checkpoint file at `path`.

    Optimizer and scheduler state are restored only when their keys are
    present in the checkpoint; `steps` falls back to 0 when absent. The
    `model` entry is required.

    Args:
        path: Checkpoint file to read with `torch.load`.
        strict: Forwarded to the model's `load_state_dict`.

    Raises:
        AssertionError: If `path` does not exist.
        KeyError: If the checkpoint has no `model` entry.
    """
    path = Path(path)

    assert path.exists(), f'checkpoint not found at {str(path)}'

    package = torch.load(str(path))

    if 'optimizer' in package:
        self.optimizer.load_state_dict(package['optimizer'])

    if 'scheduler' in package:
        self.scheduler.load_state_dict(package['scheduler'])

    self.steps = package.get('steps', 0)

    # NOTE(review): `save` stores `self.model.state_dict_with_init_args` under
    # 'model'. If that is a wrapper dict rather than a plain state dict, the
    # inner state dict must be unpacked before this call - TODO confirm
    # against the model class.
    self.model.load_state_dict(package['model'], strict = strict)
228+
229+
# shortcut methods
230+
190231
def wait(self):
    """Call the fabric's `barrier()`."""
    fabric = self.fabric
    fabric.barrier()
192233

@@ -196,6 +237,8 @@ def print(self, *args, **kwargs):
196237
def log(self, **log_data):
    """Send the keyword arguments to the fabric's `log_dict`, tagged with the current `steps`."""
    current_step = self.steps
    self.fabric.log_dict(log_data, step = current_step)
198239

240+
# main train forwards
241+
199242
def __call__(
200243
self
201244
):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.57"
3+
version = "0.0.58"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_trainer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,12 @@ def test_trainer():
134134
)
135135

136136
trainer()
137+
138+
# saving and loading from trainer
139+
140+
trainer.save('./some/nested/folder2/training')
141+
trainer.load('./some/nested/folder2/training')
142+
143+
# also allow for loading Alphafold3 directly from training ckpt
144+
145+
alphafold3 = Alphafold3.init_and_load('./some/nested/folder2/training')

0 commit comments

Comments
 (0)