Skip to content

Commit 2a52b2a

Browse files
committed
add missing_atom_indices to both Alphafold3Input as well as MoleculeInput
1 parent 66c0841 commit 2a52b2a

File tree

3 files changed

+40
-2
lines changed

3 files changed

+40
-2
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ class MoleculeInput:
213213
token_bonds: Bool['n n']
214214
molecule_atom_indices: List[int | None] | None = None
215215
distogram_atom_indices: List[int | None] | None = None
216+
missing_atom_indices: List[Int[' _'] | None] | None = None
216217
atom_parent_ids: Int[' m'] | None = None
217218
additional_token_feats: Float[f'n dtf'] | None = None
218219
templates: Float['t n n dt'] | None = None
@@ -288,6 +289,24 @@ def molecule_to_atom_input(
288289
all_num_atoms = tensor([mol.GetNumAtoms() for mol in molecules])
289290
offsets = exclusive_cumsum(all_num_atoms)
290291

292+
# handle maybe missing atom indices
293+
294+
missing_atom_mask = None
295+
296+
if exists(i.missing_atom_indices) and len(i.missing_atom_indices) > 0:
297+
298+
missing_atom_mask = []
299+
300+
for num_atoms, mol_missing_atom_indices in zip(all_num_atoms, i.missing_atom_indices):
301+
mol_miss_atom_mask = torch.zeros(num_atoms, dtype = torch.bool)
302+
303+
if exists(mol_missing_atom_indices) and mol_missing_atom_indices.numel() > 0:
304+
mol_miss_atom_mask.scatter_(-1, mol_missing_atom_indices, True)
305+
306+
missing_atom_mask.append(mol_miss_atom_mask)
307+
308+
missing_atom_mask = torch.cat(missing_atom_mask)
309+
291310
# handle maybe atompair embeds
292311

293312
atompair_ids = None
@@ -420,6 +439,7 @@ def molecule_to_atom_input(
420439
molecule_ids = i.molecule_ids,
421440
molecule_atom_indices = i.molecule_atom_indices,
422441
distogram_atom_indices = i.distogram_atom_indices,
442+
missing_atom_mask = missing_atom_mask,
423443
additional_token_feats = i.additional_token_feats,
424444
additional_molecule_feats = i.additional_molecule_feats,
425445
is_molecule_types = i.is_molecule_types,
@@ -448,6 +468,7 @@ class Alphafold3Input:
448468
ds_dna: List[Int[' _'] | str] = imm_list()
449469
ds_rna: List[Int[' _'] | str] = imm_list()
450470
atom_parent_ids: Int[' m'] | None = None
471+
missing_atom_indices: List[List[int] | None] = imm_list()
451472
additional_token_feats: Float[f'n dtf'] | None = None
452473
templates: Float['t n n dt'] | None = None
453474
msa: Float['s n dm'] | None = None
@@ -844,10 +865,25 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
844865
src_tgt_atom_indices = tensor(src_tgt_atom_indices)
845866
src_tgt_atom_indices = pad_to_len(src_tgt_atom_indices, num_tokens, value = -1, dim = -2)
846867

847-
# todo - handle atom positions for variable lengthed atoms (eventual missing atoms from mmCIF)
868+
# atom positions
848869

849870
atom_pos = i.atom_pos
850871

872+
# handle missing atom indices
873+
874+
missing_atom_indices = None
875+
876+
if exists(i.missing_atom_indices) and len(i.missing_atom_indices) > 0:
877+
missing_atom_indices = []
878+
879+
for mol_miss_atom_indices in i.missing_atom_indices:
880+
mol_miss_atom_indices = default(mol_miss_atom_indices, [])
881+
mol_miss_atom_indices = tensor(mol_miss_atom_indices, dtype = torch.long)
882+
883+
missing_atom_indices.append(mol_miss_atom_indices)
884+
885+
assert len(molecules) == len(missing_atom_indices)
886+
851887
# create molecule input
852888

853889
molecule_input = MoleculeInput(
@@ -860,6 +896,7 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
860896
additional_molecule_feats = additional_molecule_feats,
861897
additional_token_feats = default(i.additional_token_feats, torch.zeros(num_tokens, 2)),
862898
is_molecule_types = is_molecule_types,
899+
missing_atom_indices = missing_atom_indices,
863900
src_tgt_atom_indices = src_tgt_atom_indices,
864901
atom_pos = atom_pos,
865902
templates = i.templates,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.129"
3+
version = "0.1.130"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_input.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def test_atompos_input():
9292

9393
train_alphafold3_input = Alphafold3Input(
9494
proteins = [contrived_protein],
95+
missing_atom_indices = [[1, 2], None],
9596
atom_pos = mock_atompos
9697
)
9798

0 commit comments

Comments
 (0)