|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -import einx |
4 | | -import json |
5 | 3 | import os |
6 | | -import torch |
| 4 | +import json |
| 5 | +from functools import wraps, partial |
| 6 | +from dataclasses import dataclass, asdict, field |
| 7 | +from typing import Type, Literal, Callable, List, Any, Tuple |
7 | 8 |
|
| 9 | +import torch |
| 10 | +from torch import tensor |
| 11 | +from torch.nn.utils.rnn import pad_sequence |
8 | 12 | import torch.nn.functional as F |
9 | 13 |
|
10 | | -from dataclasses import dataclass, asdict, field |
11 | | -from functools import wraps, partial |
| 14 | +import einx |
| 15 | + |
12 | 16 | from loguru import logger |
13 | 17 | from pdbeccdutils.core import ccd_reader |
| 18 | + |
14 | 19 | from rdkit import Chem |
15 | 20 | from rdkit.Chem.rdchem import Atom, Mol |
16 | | -from torch import tensor |
17 | | -from typing import Type, Literal, Callable, List, Any, Tuple |
18 | 21 |
|
19 | 22 | from alphafold3_pytorch.attention import ( |
20 | 23 | pad_to_length |
@@ -292,21 +295,27 @@ def molecule_to_atom_input( |
292 | 295 | # handle maybe missing atom indices |
293 | 296 |
|
294 | 297 | missing_atom_mask = None |
| 298 | + missing_atom_indices = None |
295 | 299 |
|
296 | 300 | if exists(i.missing_atom_indices) and len(i.missing_atom_indices) > 0: |
297 | 301 |
|
298 | | - missing_atom_mask = [] |
| 302 | + missing_atom_indices: List[Int[' _']] = [default(indices, torch.empty((0,), dtype = torch.long)) for indices in i.missing_atom_indices] |
| 303 | + |
| 304 | + missing_atom_mask: List[Bool[' _']] = [] |
| 305 | + |
| 306 | + for num_atoms, mol_missing_atom_indices in zip(all_num_atoms, missing_atom_indices): |
299 | 307 |
|
300 | | - for num_atoms, mol_missing_atom_indices in zip(all_num_atoms, i.missing_atom_indices): |
301 | 308 | mol_miss_atom_mask = torch.zeros(num_atoms, dtype = torch.bool) |
302 | 309 |
|
303 | | - if exists(mol_missing_atom_indices) and mol_missing_atom_indices.numel() > 0: |
| 310 | + if mol_missing_atom_indices.numel() > 0: |
304 | 311 | mol_miss_atom_mask.scatter_(-1, mol_missing_atom_indices, True) |
305 | 312 |
|
306 | 313 | missing_atom_mask.append(mol_miss_atom_mask) |
307 | 314 |
|
308 | 315 | missing_atom_mask = torch.cat(missing_atom_mask) |
309 | 316 |
|
| 317 | + missing_atom_indices = pad_sequence(missing_atom_indices, batch_first = True, padding_value = -1) |
| 318 | + |
310 | 319 | # handle maybe atompair embeds |
311 | 320 |
|
312 | 321 | atompair_ids = None |
@@ -425,13 +434,27 @@ def molecule_to_atom_input( |
425 | 434 | row_col_slice = slice(offset, offset + num_atoms) |
426 | 435 | atompair_inputs[row_col_slice, row_col_slice] = atompair_feat |
427 | 436 |
|
| 437 | + # mask out molecule atom indices and distogram atom indices that appear in the missing atom indices list |
| 438 | + |
| 439 | + molecule_atom_indices = i.molecule_atom_indices |
| 440 | + distogram_atom_indices = i.distogram_atom_indices |
| 441 | + |
| 442 | + if exists(missing_atom_indices): |
| 443 | + is_missing_molecule_atom = einx.equal('n missing, n -> n missing', missing_atom_indices, molecule_atom_indices).any(dim = -1) |
| 444 | + is_missing_distogram_atom = einx.equal('n missing, n -> n missing', missing_atom_indices, distogram_atom_indices).any(dim = -1) |
| 445 | + |
| 446 | + molecule_atom_indices = molecule_atom_indices.masked_fill(is_missing_molecule_atom, -1) |
| 447 | + distogram_atom_indices = distogram_atom_indices.masked_fill(is_missing_distogram_atom, -1) |
| 448 | + |
428 | 449 | # handle atom positions |
429 | 450 |
|
430 | 451 | atom_pos = i.atom_pos |
431 | 452 |
|
432 | 453 | if exists(atom_pos) and isinstance(atom_pos, list): |
433 | 454 | atom_pos = torch.cat(atom_pos, dim = -2) |
434 | 455 |
|
| 456 | + # atom input |
| 457 | + |
435 | 458 | atom_input = AtomInput( |
436 | 459 | atom_inputs = atom_inputs_tensor, |
437 | 460 | atompair_inputs = atompair_inputs, |
|
0 commit comments