Skip to content

Commit 99fd519

Browse files
committed
fix trainer save / load
1 parent 5fe0cbd commit 99fd519

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

alphafold3_pytorch/trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,14 +215,16 @@ def save(self, path: str | Path, overwrite = False):
215215
steps = self.steps
216216
)
217217

218-
torch.save(str(path), package)
218+
torch.save(package, str(path))
219219

220220
def load(self, path: str | Path, strict = True):
221221
if isinstance(path, str):
222222
path = Path(path)
223223

224224
assert path.exists()
225225

226+
self.model.load(path)
227+
226228
package = torch.load(str(path))
227229

228230
if 'optimizer' in package:
@@ -233,8 +235,6 @@ def load(self, path: str | Path, strict = True):
233235

234236
self.steps = package.get('steps', 0)
235237

236-
self.model.load_state_dict(package['model'])
237-
238238
# shortcut methods
239239

240240
def wait(self):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.59"
3+
version = "0.0.60"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def test_trainer():
139139

140140
# saving and loading from trainer
141141

142-
trainer.save('./some/nested/folder2/training')
142+
trainer.save('./some/nested/folder2/training', overwrite = True)
143143
trainer.load('./some/nested/folder2/training')
144144

145145
# also allow for loading Alphafold3 directly from training ckpt

0 commit comments

Comments
 (0)