Skip to content

Commit 9e3f5fe

Browse files
committed
add ability to specify custom atom list
1 parent b235274 commit 9e3f5fe

File tree

2 files changed

+30
-18
lines changed

2 files changed

+30
-18
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -761,7 +761,7 @@ class MoleculeInput:
761761
directed_bonds: bool = False
762762
extract_atom_feats_fn: Callable[[Atom], Float["m dai"]] = default_extract_atom_feats_fn # type: ignore
763763
extract_atompair_feats_fn: Callable[[Mol], Float["m m dapi"]] = default_extract_atompair_feats_fn # type: ignore
764-
764+
custom_atoms: List[str]| None = None
765765

766766
@typecheck
767767
def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
@@ -803,7 +803,9 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
803803
atom_ids = None
804804

805805
if i.add_atom_ids:
806-
atom_index = {symbol: i for i, symbol in enumerate(ATOMS)}
806+
atom_list = default(i.custom_atoms, ATOMS)
807+
808+
atom_index = {symbol: i for i, symbol in enumerate(atom_list)}
807809

808810
atom_ids = []
809811

@@ -1101,6 +1103,7 @@ class MoleculeLengthMoleculeInput:
11011103
directed_bonds: bool = False
11021104
extract_atom_feats_fn: Callable[[Atom], Float["m dai"]] = default_extract_atom_feats_fn # type: ignore
11031105
extract_atompair_feats_fn: Callable[[Mol], Float["m m dapi"]] = default_extract_atompair_feats_fn # type: ignore
1106+
custom_atoms: List[str]| None = None
11041107

11051108

11061109
@typecheck
@@ -1264,7 +1267,9 @@ def molecule_lengthed_molecule_input_to_atom_input(
12641267
atom_ids = None
12651268

12661269
if i.add_atom_ids:
1267-
atom_index = {symbol: i for i, symbol in enumerate(ATOMS)}
1270+
atom_list = default(i.custom_atoms, ATOMS)
1271+
1272+
atom_index = {symbol: i for i, symbol in enumerate(atom_list)}
12681273

12691274
atom_ids = []
12701275

@@ -1547,7 +1552,7 @@ class Alphafold3Input:
15471552
directed_bonds: bool = False
15481553
extract_atom_feats_fn: Callable[[Atom], Float["m dai"]] = default_extract_atom_feats_fn # type: ignore
15491554
extract_atompair_feats_fn: Callable[[Mol], Float["m m dapi"]] = default_extract_atompair_feats_fn # type: ignore
1550-
1555+
custom_atoms: List[str] | None = None
15511556

15521557
@typecheck
15531558
def map_int_or_string_indices_to_mol(
@@ -1994,6 +1999,7 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
19941999
directed_bonds=i.directed_bonds,
19952000
extract_atom_feats_fn=i.extract_atom_feats_fn,
19962001
extract_atompair_feats_fn=i.extract_atompair_feats_fn,
2002+
custom_atoms=i.custom_atoms
19972003
)
19982004

19992005
return molecule_input
@@ -3975,6 +3981,7 @@ def pdb_input_to_molecule_input(
39753981
directed_bonds=i.directed_bonds,
39763982
extract_atom_feats_fn=i.extract_atom_feats_fn,
39773983
extract_atompair_feats_fn=i.extract_atompair_feats_fn,
3984+
custom_atoms=i.custom_atoms
39783985
)
39793986

39803987
return molecule_input

tests/test_input.py

Lines changed: 19 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020

2121
from alphafold3_pytorch.data.data_pipeline import *
2222
from alphafold3_pytorch.data.data_pipeline import make_mmcif_features
23+
2324
from alphafold3_pytorch.common.biomolecule import (
2425
Biomolecule,
2526
_from_mmcif_object,
@@ -33,6 +34,7 @@
3334
from alphafold3_pytorch.data import mmcif_writing, mmcif_parsing
3435

3536
from alphafold3_pytorch.life import (
37+
ATOMS,
3638
reverse_complement,
3739
reverse_complement_tensor
3840
)
@@ -84,6 +86,8 @@ def test_alphafold3_input(
8486
directed_bonds
8587
):
8688

89+
CUSTOM_ATOMS = list({*ATOMS, 'Na', 'Fe', 'Si', 'F', 'K'})
90+
8791
alphafold3_input = Alphafold3Input(
8892
proteins = ['MLEICLKLVGCKSKKGLSSSSSCYLEEALQRPVASDF', 'MGKCRGLRTARKLRSHRRDQKWHDKQYKKAHLGTALKANPFGGASHAKGIVLEKVGVEAKQPNSAIRKCVRVQLIKNGKKITAFVPNDGCLNFIEENDEVLVAGFGRKGHAVGDIPGVRFKVVKVANVSLLALYKGKKERPRS'],
8993
ds_dna = ['ACGTT'],
@@ -95,7 +99,8 @@ def test_alphafold3_input(
9599
ligands = ['CC1=C(C=C(C=C1)NC(=O)C2=CC=C(C=C2)CN3CCN(CC3)C)NC4=NC=CC(=N4)C5=CN=CC=C5'],
96100
add_atom_ids = True,
97101
add_atompair_ids = True,
98-
directed_bonds = directed_bonds
102+
directed_bonds = directed_bonds,
103+
custom_atoms = CUSTOM_ATOMS
99104
)
100105

101106
batched_atom_input = alphafold3_inputs_to_batched_atom_input(alphafold3_input)
@@ -107,7 +112,7 @@ def test_alphafold3_input(
107112
alphafold3 = Alphafold3(
108113
dim_atom_inputs = 3,
109114
dim_atompair_inputs = 5,
110-
num_atom_embeds = 47,
115+
num_atom_embeds = len(CUSTOM_ATOMS),
111116
num_atompair_embeds = num_atom_bond_types + 1, # 0 is for no bond
112117
atoms_per_window = 27,
113118
dim_template_feats = 108,
@@ -187,17 +192,17 @@ def test_alphafold3_input_to_mmcif(tmp_path):
187192
"""Test the Inference I/O Pipeline. This codifies the data_pipeline.py file used for training."""
188193

189194
alphafold3_input = Alphafold3Input(
190-
proteins = ['MLEICLKLVGCKSKKGLSSSSSCYLEEALQRPVASDF', 'MGKCRGLRTARKLRSHRRDQKWHDKQYKKAHLGTALKANPFGGASHAKGIVLEKVGVEAKQPNSAIRKCVRVQLIKNGKKITAFVPNDGCLNFIEENDEVLVAGFGRKGHAVGDIPGVRFKVVKVANVSLLALYKGKKERPRS'],
191-
ds_dna = ['ACGTT'],
192-
ds_rna = ['GCCAU', 'CCAGU'],
193-
ss_dna = ['GCCTA'],
194-
ss_rna = ['CGCAUA'],
195-
metal_ions = ['Na', 'Na', 'Fe'],
196-
misc_molecule_ids = ['Phospholipid'],
197-
ligands = ['CC1=C(C=C(C=C1)NC(=O)C2=CC=C(C=C2)CN3CCN(CC3)C)NC4=NC=CC(=N4)C5=CN=CC=C5'],
198-
add_atom_ids = True,
199-
add_atompair_ids = True,
200-
directed_bonds = True
195+
proteins = ['MLEICLKLVGCKSKKGLSSSSSCYLEEALQRPVASDF', 'MGKCRGLRTARKLRSHRRDQKWHDKQYKKAHLGTALKANPFGGASHAKGIVLEKVGVEAKQPNSAIRKCVRVQLIKNGKKITAFVPNDGCLNFIEENDEVLVAGFGRKGHAVGDIPGVRFKVVKVANVSLLALYKGKKERPRS'],
196+
ds_dna = ['ACGTT'],
197+
ds_rna = ['GCCAU', 'CCAGU'],
198+
ss_dna = ['GCCTA'],
199+
ss_rna = ['CGCAUA'],
200+
metal_ions = ['Na', 'Na', 'Fe'],
201+
misc_molecule_ids = ['Phospholipid'],
202+
ligands = ['CC1=C(C=C(C=C1)NC(=O)C2=CC=C(C=C2)CN3CCN(CC3)C)NC4=NC=CC(=N4)C5=CN=CC=C5'],
203+
add_atom_ids = True,
204+
add_atompair_ids = True,
205+
directed_bonds = True
201206
)
202207

203208
test_biomol = alphafold3_input_to_biomolecule(alphafold3_input, atom_positions=torch.randn(261, 47, 3).numpy())
@@ -328,7 +333,7 @@ def test_atompos_input():
328333
atom_encoder_depth = 1,
329334
token_transformer_depth = 1,
330335
atom_decoder_depth = 1,
331-
)
336+
),
332337
)
333338

334339
loss = alphafold3(**batched_atom_input.model_forward_dict())

0 commit comments

Comments
 (0)