Skip to content

Commit 85bb181

Browse files
committed
automatically take care of masking out the distogram atom index, if it falls under the missing atom indices
1 parent 2a52b2a commit 85bb181

File tree

2 files changed

+34
-11
lines changed

2 files changed

+34
-11
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
from __future__ import annotations
22

3-
import einx
4-
import json
53
import os
6-
import torch
4+
import json
5+
from functools import wraps, partial
6+
from dataclasses import dataclass, asdict, field
7+
from typing import Type, Literal, Callable, List, Any, Tuple
78

9+
import torch
10+
from torch import tensor
11+
from torch.nn.utils.rnn import pad_sequence
812
import torch.nn.functional as F
913

10-
from dataclasses import dataclass, asdict, field
11-
from functools import wraps, partial
14+
import einx
15+
1216
from loguru import logger
1317
from pdbeccdutils.core import ccd_reader
18+
1419
from rdkit import Chem
1520
from rdkit.Chem.rdchem import Atom, Mol
16-
from torch import tensor
17-
from typing import Type, Literal, Callable, List, Any, Tuple
1821

1922
from alphafold3_pytorch.attention import (
2023
pad_to_length
@@ -292,21 +295,27 @@ def molecule_to_atom_input(
292295
# handle maybe missing atom indices
293296

294297
missing_atom_mask = None
298+
missing_atom_indices = None
295299

296300
if exists(i.missing_atom_indices) and len(i.missing_atom_indices) > 0:
297301

298-
missing_atom_mask = []
302+
missing_atom_indices: List[Int[' _']] = [default(indices, torch.empty((0,), dtype = torch.long)) for indices in i.missing_atom_indices]
303+
304+
missing_atom_mask: List[Bool[' _']] = []
305+
306+
for num_atoms, mol_missing_atom_indices in zip(all_num_atoms, missing_atom_indices):
299307

300-
for num_atoms, mol_missing_atom_indices in zip(all_num_atoms, i.missing_atom_indices):
301308
mol_miss_atom_mask = torch.zeros(num_atoms, dtype = torch.bool)
302309

303-
if exists(mol_missing_atom_indices) and mol_missing_atom_indices.numel() > 0:
310+
if mol_missing_atom_indices.numel() > 0:
304311
mol_miss_atom_mask.scatter_(-1, mol_missing_atom_indices, True)
305312

306313
missing_atom_mask.append(mol_miss_atom_mask)
307314

308315
missing_atom_mask = torch.cat(missing_atom_mask)
309316

317+
missing_atom_indices = pad_sequence(missing_atom_indices, batch_first = True, padding_value = -1)
318+
310319
# handle maybe atompair embeds
311320

312321
atompair_ids = None
@@ -425,13 +434,27 @@ def molecule_to_atom_input(
425434
row_col_slice = slice(offset, offset + num_atoms)
426435
atompair_inputs[row_col_slice, row_col_slice] = atompair_feat
427436

437+
# mask out molecule atom indices and distogram atom indices where it is in the missing atom indices list
438+
439+
molecule_atom_indices = i.molecule_atom_indices
440+
distogram_atom_indices = i.distogram_atom_indices
441+
442+
if exists(missing_atom_indices):
443+
is_missing_molecule_atom = einx.equal('n missing, n -> n missing', missing_atom_indices, molecule_atom_indices).any(dim = -1)
444+
is_missing_distogram_atom = einx.equal('n missing, n -> n missing', missing_atom_indices, distogram_atom_indices).any(dim = -1)
445+
446+
molecule_atom_indices = molecule_atom_indices.masked_fill(is_missing_molecule_atom, -1)
447+
distogram_atom_indices = distogram_atom_indices.masked_fill(is_missing_distogram_atom, -1)
448+
428449
# handle atom positions
429450

430451
atom_pos = i.atom_pos
431452

432453
if exists(atom_pos) and isinstance(atom_pos, list):
433454
atom_pos = torch.cat(atom_pos, dim = -2)
434455

456+
# atom input
457+
435458
atom_input = AtomInput(
436459
atom_inputs = atom_inputs_tensor,
437460
atompair_inputs = atompair_inputs,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.130"
3+
version = "0.1.131"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)