@@ -10,9 +10,10 @@ |
10 | 10 |
11 | 11 | import torch |
12 | 12 | from torch import nn |
13 | | -from torch import Tensor, tensor |
| 13 | +from torch import Tensor, tensor, is_tensor |
14 | 14 | from torch.amp import autocast |
15 | 15 | import torch.nn.functional as F |
| 16 | +from torch.utils._pytree import tree_map |
16 | 17 |
17 | 18 | from torch.nn import ( |
18 | 19 | Module, |
@@ -227,6 +228,9 @@ def freeze_(m: Module): |
227 | 228 | def max_neg_value(t: Tensor): |
228 | 229 | return -torch.finfo(t.dtype).max |
229 | 230 |
| 231 | +def dict_to_device(d, device): |
| 232 | + return tree_map(lambda t: t.to(device) if is_tensor(t) else t, d) |
| 233 | + |
230 | 234 | def pack_one(t, pattern): |
231 | 235 | packed, ps = pack([t], pattern) |
232 | 236 |
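A minimal standalone sketch of what the new `dict_to_device` helper does (the batch keys below are illustrative, not the model's exact forward arguments): `tree_map` recurses through nested dicts, lists, and tuples, moving only the tensor leaves and passing every other value through unchanged.

```python
import torch
from torch import is_tensor
from torch.utils._pytree import tree_map

def dict_to_device(d, device):
    # move every tensor leaf in a (possibly nested) container to `device`
    return tree_map(lambda t: t.to(device) if is_tensor(t) else t, d)

batch = {
    'atom_inputs': torch.randn(2, 16),             # tensor leaf   -> moved
    'num_recycling_steps': 1,                      # plain int     -> untouched
    'nested': {'mask': torch.ones(2, 16).bool()},  # nested tensor -> moved
}

moved = dict_to_device(batch, 'cpu')  # swap in 'cuda' on a GPU machine
assert moved['num_recycling_steps'] == 1
assert moved['nested']['mask'].device.type == 'cpu'
```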
@@ -263,7 +267,7 @@ def should_checkpoint( |
263 | 267 | inputs: Tensor | Tuple[Tensor, ...], |
264 | 268 | check_instance_variable: str | None = 'checkpoint' |
265 | 269 | ) -> bool: |
266 | | - if torch.is_tensor(inputs): |
| 270 | + if is_tensor(inputs): |
267 | 271 | inputs = (inputs,) |
268 | 272 |
268 | 272 |
269 | 273 | return ( |
@@ -6344,7 +6348,11 @@ def forward_with_alphafold3_inputs( |
6344 | 6348 | alphafold3_inputs = [alphafold3_inputs] |
6345 | 6349 |
|
6346 | 6350 | batched_atom_inputs = alphafold3_inputs_to_batched_atom_input(alphafold3_inputs, atoms_per_window = self.w) |
6347 | | - return self.forward(**batched_atom_inputs.model_forward_dict(), **kwargs) |
| 6351 | + |
| 6352 | + atom_dict = batched_atom_inputs.model_forward_dict() |
| 6353 | + atom_dict = dict_to_device(atom_dict, device = self.device) |
| 6354 | + |
| 6355 | + return self.forward(**atom_dict, **kwargs) |
6348 | 6356 |
6349 | 6357 | @typecheck |
6350 | 6358 | def forward( |
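The effect of the last hunk, as a hedged usage sketch (the constructor values and the contrived two-residue protein are assumptions in the style of the project's README, not verified values): the atom tensors built by `alphafold3_inputs_to_batched_atom_input` are created on CPU, and `dict_to_device` now relocates them to `self.device`, so a CUDA-resident model can be fed `Alphafold3Input` objects directly.

```python
# Hedged sketch, not the repository's test code: constructor arguments and
# the contrived protein below are assumptions for illustration only.
import torch
from alphafold3_pytorch import Alphafold3, Alphafold3Input

alphafold3 = Alphafold3(
    dim_atom_inputs = 77,       # assumed README-style values
    dim_template_feats = 108
)

if torch.cuda.is_available():
    alphafold3 = alphafold3.cuda()  # parameters now live on the GPU

alphafold3_input = Alphafold3Input(
    proteins = ['AG']  # contrived two-residue chain
)

# Previously the batched atom tensors stayed on CPU, so a CUDA-resident model
# raised a device mismatch here; dict_to_device now moves them to self.device
# before forward is called.
sampled_atom_pos = alphafold3.forward_with_alphafold3_inputs(alphafold3_input)
```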