Skip to content

Commit ead35b7

Browse files
committed
auto move inputs to same device as model when invoking forward_with_alphafold3_inputs
1 parent 3e1795e commit ead35b7

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010

1111
import torch
1212
from torch import nn
13-
from torch import Tensor, tensor
13+
from torch import Tensor, tensor, is_tensor
1414
from torch.amp import autocast
1515
import torch.nn.functional as F
16+
from torch.utils._pytree import tree_map
1617

1718
from torch.nn import (
1819
Module,
@@ -227,6 +228,9 @@ def freeze_(m: Module):
227228
def max_neg_value(t: Tensor):
228229
return -torch.finfo(t.dtype).max
229230

231+
def dict_to_device(d, device):
232+
return tree_map(lambda t: t.to(device) if is_tensor(t) else t, d)
233+
230234
def pack_one(t, pattern):
231235
packed, ps = pack([t], pattern)
232236

@@ -263,7 +267,7 @@ def should_checkpoint(
263267
inputs: Tensor | Tuple[Tensor, ...],
264268
check_instance_variable: str | None = 'checkpoint'
265269
) -> bool:
266-
if torch.is_tensor(inputs):
270+
if is_tensor(inputs):
267271
inputs = (inputs,)
268272

269273
return (
@@ -6344,7 +6348,11 @@ def forward_with_alphafold3_inputs(
63446348
alphafold3_inputs = [alphafold3_inputs]
63456349

63466350
batched_atom_inputs = alphafold3_inputs_to_batched_atom_input(alphafold3_inputs, atoms_per_window = self.w)
6347-
return self.forward(**batched_atom_inputs.model_forward_dict(), **kwargs)
6351+
6352+
atom_dict = batched_atom_inputs.model_forward_dict()
6353+
atom_dict = dict_to_device(atom_dict, device = self.device)
6354+
6355+
return self.forward(**atom_dict, **kwargs)
63486356

63496357
@typecheck
63506358
def forward(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.5.18"
3+
version = "0.5.19"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)