Skip to content

Commit 241318c

Browse files
authored
ema on cpu (#119)
ema on cpu
1 parent 468dfbb commit 241318c

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

alphafold3_pytorch/trainer.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from functools import wraps, partial
44
from dataclasses import asdict
5+
from contextlib import contextmanager
56
from pathlib import Path
67

78
from alphafold3_pytorch.alphafold3 import Alphafold3
@@ -64,6 +65,22 @@ def divisible_by(num, den):
6465
def at_most_one_of(*flags: bool) -> bool:
6566
return sum([*map(int, flags)]) <= 1
6667

68+
@contextmanager
69+
def to_device_and_back(
70+
module: Module,
71+
device: torch.device
72+
):
73+
orig_device = next(module.parameters()).device
74+
need_move_device = orig_device != device
75+
76+
if need_move_device:
77+
module.to(device)
78+
79+
yield
80+
81+
if need_move_device:
82+
module.to(orig_device)
83+
6784
def cycle(dataloader: DataLoader):
6885
while True:
6986
for batch in dataloader:
@@ -284,6 +301,7 @@ def __init__(
284301
ema_kwargs: dict = dict(
285302
use_foreach = True
286303
),
304+
ema_on_cpu = False,
287305
use_adam_atan2: bool = False,
288306
use_lion: bool = False,
289307
use_torch_compile: bool = False
@@ -314,9 +332,13 @@ def __init__(
314332
model,
315333
beta = ema_decay,
316334
include_online_model = False,
335+
allow_different_devices = True,
317336
**ema_kwargs
318337
)
319338

339+
self.ema_device = 'cpu' if ema_on_cpu else self.device
340+
self.ema_model.to(self.ema_device)
341+
320342
# maybe torch compile
321343

322344
if use_torch_compile:
@@ -437,6 +459,10 @@ def __init__(
437459
self.last_loaded_train_id = None
438460
self.model_loaded_from_path: Path | None = None
439461

462+
@property
463+
def device(self):
464+
return self.fabric.device
465+
440466
@property
441467
def is_main(self):
442468
return self.fabric.global_rank == 0
@@ -656,7 +682,7 @@ def __call__(
656682
):
657683
eval_model = default(self.ema_model, self.model)
658684

659-
with torch.no_grad():
685+
with torch.no_grad(), to_device_and_back(eval_model, self.device):
660686
eval_model.eval()
661687

662688
total_valid_loss = 0.
@@ -696,7 +722,7 @@ def __call__(
696722
if self.is_main and self.needs_test:
697723
eval_model = default(self.ema_model, self.model)
698724

699-
with torch.no_grad():
725+
with torch.no_grad(), to_device_and_back(eval_model, self.device):
700726
eval_model.eval()
701727

702728
total_test_loss = 0.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.73"
3+
version = "0.2.74"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)