Skip to content

Commit dc40483

Browse files
committed
unwrap model and optimizer manually during save and load for fabric
1 parent 3d9055d commit dc40483

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

alphafold3_pytorch/trainer.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ema_pytorch import EMA
2525

2626
from lightning import Fabric
27+
from lightning.fabric.wrappers import _unwrap_objects
2728

2829
# constants
2930

@@ -341,9 +342,12 @@ def save(
341342

342343
path.parent.mkdir(exist_ok = True, parents = True)
343344

345+
unwrapped_model = _unwrap_objects(self.model)
346+
unwrapped_optimizer = _unwrap_objects(self.optimizer)
347+
344348
package = dict(
345-
model = self.model.state_dict_with_init_args,
346-
optimizer = self.optimizer.state_dict(),
349+
model = unwrapped_model.state_dict_with_init_args,
350+
optimizer = unwrapped_optimizer.state_dict(),
347351
scheduler = self.scheduler.state_dict(),
348352
steps = self.steps
349353
)
@@ -379,9 +383,14 @@ def load(
379383

380384
self.model_loaded_from_path = path
381385

386+
# get unwrapped model and optimizer
387+
388+
unwrapped_model = _unwrap_objects(self.model)
389+
unwrapped_optimizer = _unwrap_objects(self.optimizer)
390+
382391
# load model from path
383392

384-
self.model.load(path)
393+
unwrapped_model.load(path)
385394

386395
if only_model:
387396
return
@@ -391,7 +400,7 @@ def load(
391400
package = torch.load(str(path))
392401

393402
if 'optimizer' in package:
394-
self.optimizer.load_state_dict(package['optimizer'])
403+
unwrapped_optimizer.load_state_dict(package['optimizer'])
395404

396405
if 'scheduler' in package:
397406
self.scheduler.load_state_dict(package['scheduler'])

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.26"
3+
version = "0.1.27"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)