@@ -2,6 +2,7 @@
 
 from functools import wraps, partial
 from dataclasses import asdict
+from contextlib import contextmanager
 from pathlib import Path
 
 from alphafold3_pytorch.alphafold3 import Alphafold3
@@ -64,6 +65,22 @@ def divisible_by(num, den): |
 def at_most_one_of(*flags: bool) -> bool:
     return sum([*map(int, flags)]) <= 1
 
+@contextmanager
+def to_device_and_back(
+    module: Module,
+    device: torch.device
+):
+    orig_device = next(module.parameters()).device
+    need_move_device = orig_device != device
+
+    if need_move_device:
+        module.to(device)
+
+    yield
+
+    if need_move_device:
+        module.to(orig_device)
+
 def cycle(dataloader: DataLoader):
     while True:
         for batch in dataloader:
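
The new to_device_and_back helper round-trips a module onto a target device for the duration of a block, and is a no-op when the module already lives there. A minimal usage sketch, assuming a CUDA machine and a toy Linear module (both illustrative, not part of the diff):

import torch
from torch.nn import Linear

module = Linear(4, 4)                                  # starts out on CPU

with to_device_and_back(module, torch.device('cuda')):
    # inside the block the parameters live on CUDA
    out = module(torch.randn(1, 4, device = 'cuda'))

# on exit the module is restored to its original device
assert next(module.parameters()).device.type == 'cpu'

Note the context manager does not wrap its yield in try/finally, so an exception raised inside the block leaves the module on the target device.
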
@@ -284,6 +301,7 @@ def __init__( |
         ema_kwargs: dict = dict(
             use_foreach = True
         ),
+        ema_on_cpu: bool = False,
         use_adam_atan2: bool = False,
         use_lion: bool = False,
         use_torch_compile: bool = False
@@ -314,9 +332,13 @@ def __init__( |
                 model,
                 beta = ema_decay,
                 include_online_model = False,
+                allow_different_devices = True,
                 **ema_kwargs
             )
 
+            self.ema_device = 'cpu' if ema_on_cpu else self.device
+            self.ema_model.to(self.ema_device)
+
         # maybe torch compile
 
         if use_torch_compile:
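
With ema_on_cpu = True the EMA shadow weights are kept in system RAM rather than on the accelerator, and allow_different_devices = True tells ema-pytorch to tolerate the online and EMA copies living on different devices. A minimal sketch of the same pattern outside the Trainer, with a toy module standing in for Alphafold3 (hyperparameters illustrative):

import torch
from torch import nn
from ema_pytorch import EMA

online = nn.Linear(8, 8).cuda()          # online model trains on the GPU

ema = EMA(
    online,
    beta = 0.999,
    include_online_model = False,
    allow_different_devices = True       # shadow may live on another device
)
ema.to('cpu')                            # shadow copy stays in system RAM

for _ in range(10):
    # ... optimizer step on the CUDA online model would go here ...
    ema.update()                         # refreshes the CPU-resident shadow

The trade-off is a host-device copy on every EMA update (and again at eval time, via to_device_and_back), in exchange for not holding a second full set of weights on the GPU.
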
@@ -437,6 +459,10 @@ def __init__( |
         self.last_loaded_train_id = None
         self.model_loaded_from_path: Path | None = None
 
+    @property
+    def device(self):
+        return self.fabric.device
+
     @property
     def is_main(self):
         return self.fabric.global_rank == 0
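
The device property mirrors is_main in delegating to Lightning Fabric, giving the EMA placement and evaluation code a single source of truth for the process' root device. A small sketch of what it resolves to, assuming lightning.fabric.Fabric is the trainer's fabric backend:

from lightning.fabric import Fabric

fabric = Fabric(accelerator = 'auto', devices = 1)
fabric.launch()
print(fabric.device)   # e.g. device(type='cuda', index=0) on a single-GPU run
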
@@ -656,7 +682,7 @@ def __call__( |
             ):
                 eval_model = default(self.ema_model, self.model)
 
-                with torch.no_grad():
+                with torch.no_grad(), to_device_and_back(eval_model, self.device):
                     eval_model.eval()
 
                     total_valid_loss = 0.
@@ -696,7 +722,7 @@ def __call__( |
         if self.is_main and self.needs_test:
             eval_model = default(self.ema_model, self.model)
 
-            with torch.no_grad():
+            with torch.no_grad(), to_device_and_back(eval_model, self.device):
                 eval_model.eval()
 
                 total_test_loss = 0.