Skip to content

Commit 099279a

Browse files
authored
Add various bug fixes and optimizations for pdb_input_to_molecule_input() (#136)
* Update amino_acid_constants.py * Update dna_constants.py * Update rna_constants.py * Update test_af3.py * Update data_utils.py * Update amino_acid_constants.py * Update dna_constants.py * Update rna_constants.py * Update inputs.py * Update inputs.py
1 parent 92a9011 commit 099279a

File tree

6 files changed

+82
-34
lines changed

6 files changed

+82
-34
lines changed

alphafold3_pytorch/common/amino_acid_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
"_",
5757
"_", # 10 null types.
5858
]
59+
element_types = [atom_type[0] for atom_type in atom_types]
5960
atom_types_set = set(atom_types)
6061
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
6162
atom_type_num = len(atom_types) # := 37 + 10 null types := 47.

alphafold3_pytorch/common/dna_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
"_",
6060
"_", # 19 null types.
6161
]
62+
element_types = [atom_type[0] for atom_type in atom_types]
6263
atom_types_set = set(atom_types)
6364
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
6465
atom_type_num = len(atom_types) # := 28 + 19 null types := 47.

alphafold3_pytorch/common/rna_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
"_",
6060
"_", # 19 null types.
6161
]
62+
element_types = [atom_type[0] for atom_type in atom_types]
6263
atom_types_set = set(atom_types)
6364
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
6465
atom_type_num = len(atom_types) # := 28 + 19 null types := 47.

alphafold3_pytorch/inputs.py

Lines changed: 66 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from functools import partial, wraps
88
from itertools import groupby
99
from collections import defaultdict
10-
from collections.abc import Iterable
10+
from collections.abc import Iterableassignment
1111
from dataclasses import asdict, dataclass, field
1212
from typing import Any, Callable, Dict, List, Literal, Set, Tuple, Type
1313

@@ -28,7 +28,7 @@
2828

2929
from pdbeccdutils.core import ccd_reader
3030

31-
from rdkit import Chem
31+
from rdkit import Chem, RDLogger
3232
from rdkit.Chem import AllChem, rdDetermineBonds
3333
from rdkit.Chem.rdchem import Atom, Mol
3434
from rdkit.Geometry import Point3D
@@ -61,10 +61,15 @@
6161
get_pdb_input_residue_molecule_type,
6262
is_atomized_residue,
6363
is_polymer,
64+
remove_last_digit_character,
6465
)
6566
from alphafold3_pytorch.utils.model_utils import exclusive_cumsum
6667
from alphafold3_pytorch.utils.utils import default, exists, first, identity
6768

69+
# silence RDKit's warnings
70+
71+
RDLogger.DisableLog("rdApp.*")
72+
6873
# constants
6974

7075
IS_MOLECULE_TYPES = 5
@@ -84,7 +89,7 @@
8489
MOLECULE_METAL_ION_ID = MOLECULE_GAP_ID + 1
8590
NUM_MOLECULE_IDS = len(HUMAN_AMINO_ACIDS) + len(RNA_NUCLEOTIDES) + len(DNA_NUCLEOTIDES) + 2
8691

87-
DEFAULT_NUM_MOLECULE_MODS = 5
92+
DEFAULT_NUM_MOLECULE_MODS = 4 # `mod_protein`, `mod_rna`, `mod_dna`, and `mod_unk`
8893
ADDITIONAL_MOLECULE_FEATS = 5
8994

9095
CCD_COMPONENTS_FILEPATH = os.path.join('data', 'ccd_data', 'components.cif')
@@ -669,7 +674,7 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
669674
atom_input = AtomInput(
670675
atom_inputs=atom_inputs_tensor,
671676
atompair_inputs=atompair_inputs,
672-
molecule_atom_lens=tensor(atom_lens, dtype=torch.long),
677+
molecule_atom_lens=atom_lens.long(),
673678
molecule_ids=i.molecule_ids,
674679
molecule_atom_indices=i.molecule_atom_indices,
675680
distogram_atom_indices=i.distogram_atom_indices,
@@ -1750,20 +1755,24 @@ def add_atom_positions_to_mol(
17501755

17511756

17521757
def create_mol_from_atom_positions_and_types(
1758+
name: str,
17531759
atom_positions: np.ndarray,
17541760
element_types: List[str],
17551761
missing_atom_indices: Set[int],
17561762
num_bond_attempts: int = 2,
1763+
verbose: bool = False,
17571764
) -> Mol:
17581765
"""Create an RDKit molecule from a NumPy array of atom positions and a list of their element
17591766
types.
17601767
1768+
:param name: The name of the molecule.
17611769
:param atom_positions: A NumPy array of shape (num_atoms, 3) containing the 3D coordinates of
17621770
each atom.
17631771
:param element_types: A list of element symbols for each atom in the molecule.
17641772
:param missing_atom_indices: A set of atom indices that are missing from the atom_positions
17651773
array.
17661774
:param num_bond_attempts: The number of attempts to determine the bonds in the molecule.
1775+
:param verbose: Whether to log warnings when bond determination fails.
17671776
:return: An RDKit molecule with the specified atom positions and element types.
17681777
"""
17691778
if len(atom_positions) != len(element_types):
@@ -1772,6 +1781,7 @@ def create_mol_from_atom_positions_and_types(
17721781
# populate an empty editable molecule
17731782

17741783
mol = Chem.RWMol()
1784+
mol.SetProp("_Name", name)
17751785

17761786
for element_type in element_types:
17771787
atom = Chem.Atom(element_type)
@@ -1789,18 +1799,23 @@ def create_mol_from_atom_positions_and_types(
17891799
# finalize molecule by inferring bonds
17901800

17911801
determined_bonds = False
1792-
for _ in range(num_bond_attempts):
1802+
for i in range(num_bond_attempts):
17931803
try:
1794-
rdDetermineBonds.DetermineBonds(mol, charge=Chem.GetFormalCharge(mol))
1804+
charge = Chem.GetFormalCharge(mol)
1805+
rdDetermineBonds.DetermineBonds(mol, charge=charge)
17951806
determined_bonds = True
17961807
except Exception as e:
1797-
logger.warning(
1798-
f"Failed to determine bonds in the input molecule due to: {e}. "
1799-
"Retrying once more."
1800-
)
1808+
if verbose:
1809+
logger.warning(
1810+
f"Failed to determine bonds for the input molecule {name} due to: {e}. "
1811+
f"{'Retrying once more.' if i < num_bond_attempts - 1 else 'Terminating bond assignment.'}"
1812+
)
18011813
continue
18021814
if not determined_bonds:
1803-
raise ValueError("Failed to determine bonds in the input molecule.")
1815+
if verbose:
1816+
logger.warning(
1817+
"Failed to determine bonds in the input molecule. Skipping bond assignment."
1818+
)
18041819

18051820
mol = Chem.RemoveHs(mol)
18061821
Chem.SanitizeMol(mol)
@@ -1818,6 +1833,7 @@ def extract_template_molecules_from_biomolecule_chains(
18181833
chain_seqs: List[str],
18191834
chain_chem_types: List[PDB_INPUT_RESIDUE_MOLECULE_TYPE],
18201835
mol_keyname: str = "rdchem_mol",
1836+
verbose: bool = False,
18211837
) -> Tuple[List[Mol], List[PDB_INPUT_RESIDUE_MOLECULE_TYPE]]:
18221838
"""Extract RDKit template molecules and their types for the residues of each `Biomolecule`
18231839
chain.
@@ -1902,23 +1918,26 @@ def extract_template_molecules_from_biomolecule_chains(
19021918
res_atom_type_indices = np.where(res_atom_positions.all(axis=-1))[1]
19031919
res_atom_elements = [
19041920
# NOTE: here, we treat the first character of each atom type as its element symbol
1905-
res_constants.element_types[idx]
1921+
res_constants.element_types[idx].replace("ATM", "*")
19061922
for idx in res_atom_type_indices
19071923
]
19081924
mol = create_mol_from_atom_positions_and_types(
19091925
# NOTE: for now, we construct molecules without referencing canonical
19101926
# SMILES strings, which means there are no missing molecule atoms by design
1911-
res_atom_positions[res_atom_mask],
1912-
res_atom_elements,
1927+
name=seq,
1928+
atom_positions=res_atom_positions[res_atom_mask],
1929+
element_types=res_atom_elements,
19131930
missing_atom_indices=set(),
1931+
verbose=verbose,
19141932
)
19151933
try:
19161934
mol = AllChem.AssignBondOrdersFromTemplate(template_mol, mol)
19171935
except Exception as e:
1918-
logger.warning(
1919-
f"Failed to assign bond orders from the template ligand molecule for residue {res} due to: {e}. "
1920-
"Skipping bond order assignment."
1921-
)
1936+
if verbose:
1937+
logger.warning(
1938+
f"Failed to assign bond orders from the template atomized molecule for residue {seq} due to: {e}. "
1939+
"Skipping bond order assignment."
1940+
)
19221941
res_index += mol.GetNumAtoms()
19231942

19241943
# (Unmodified) polymer residues
@@ -1977,6 +1996,7 @@ def extract_template_molecules_from_biomolecule_chains(
19771996
res_atom_positions.reshape(-1, 3),
19781997
missing_atom_indices,
19791998
)
1999+
mol.SetProp("_Name", res)
19802000
res_index += 1
19812001

19822002
mol_seq.append(mol)
@@ -2079,7 +2099,9 @@ def find_mismatched_symmetry(
20792099

20802100
@typecheck
20812101
def pdb_input_to_molecule_input(
2082-
pdb_input: PDBInput, biomol: Biomolecule | None = None
2102+
pdb_input: PDBInput,
2103+
biomol: Biomolecule | None = None,
2104+
verbose: bool = False,
20832105
) -> MoleculeInput:
20842106
"""Convert a PDBInput to a MoleculeInput."""
20852107
i = pdb_input
@@ -2173,16 +2195,19 @@ def pdb_input_to_molecule_input(
21732195
biomol,
21742196
chain_seqs,
21752197
chain_chem_types,
2198+
verbose=verbose,
21762199
)
21772200

21782201
# collect pooling lengths and atom-wise molecule types for each molecule,
2179-
# along with a token-wise boolean tensor indicating whether each molecule is modified
2202+
# along with a token-wise one-hot tensor indicating whether each molecule is modified
2203+
# and, if so, which type of modification it has (e.g., peptide vs. RNA modification)
21802204
molecule_idx = 0
21812205
token_pool_lens = []
21822206
molecule_atom_types = []
21832207
is_molecule_mod = []
21842208
for mol, mol_type in zip(molecules, molecule_types):
21852209
num_atoms = mol.GetNumAtoms()
2210+
is_mol_mod_type = [False for _ in range(DEFAULT_NUM_MOLECULE_MODS)]
21862211
if is_atomized_residue(mol_type):
21872212
# NOTE: in the paper, they treat each atom of the ligand and modified polymer residues as a token
21882213
token_pool_lens.extend([1] * num_atoms)
@@ -2201,26 +2226,30 @@ def pdb_input_to_molecule_input(
22012226
is_mol_type_index = IS_LIGAND_INDEX
22022227
elif mol_type == "mod_protein":
22032228
is_mol_type_index = IS_PROTEIN_INDEX
2229+
is_mol_mod_type_index = 0
22042230
elif mol_type == "mod_rna":
22052231
is_mol_type_index = IS_RNA_INDEX
2232+
is_mol_mod_type_index = 1
22062233
elif mol_type == "mod_dna":
22072234
is_mol_type_index = IS_DNA_INDEX
2235+
is_mol_mod_type_index = 2
22082236
else:
22092237
raise ValueError(f"Unrecognized molecule type: {mol_type}")
22102238

22112239
is_molecule_types[molecule_type_row_idx, is_mol_type_index] = True
22122240

2213-
is_molecule_mod.extend([True if "mod" in mol_type else False] * num_atoms)
2241+
if "mod" in mol_type:
2242+
is_mol_mod_type[is_mol_mod_type_index] = True
2243+
is_molecule_mod.extend([is_mol_mod_type] * num_atoms)
22142244

22152245
molecule_idx += num_atoms
22162246
else:
22172247
token_pool_lens.append(num_atoms)
22182248
molecule_atom_types.append(mol_type)
2219-
is_molecule_mod.append(False)
2249+
is_molecule_mod.append(is_mol_mod_type)
22202250
molecule_idx += 1
22212251

22222252
# collect token center, distogram, and source-target atom indices for each token
2223-
molecule_idx = 0
22242253
molecule_atom_indices = []
22252254
distogram_atom_indices = []
22262255
src_tgt_atom_indices = []
@@ -2415,35 +2444,39 @@ def pdb_input_to_molecule_input(
24152444
ptnr2_atom_id = (
24162445
f"{bond.ptnr2_auth_asym_id}:{bond.ptnr2_auth_seq_id}:{bond.ptnr2_label_atom_id}"
24172446
)
2447+
ptnr1_label_atom_id = remove_last_digit_character(bond.ptnr1_label_atom_id)
2448+
ptnr2_label_atom_id = remove_last_digit_character(bond.ptnr2_label_atom_id)
24182449
try:
24192450
row_idx = get_token_index_from_composite_atom_id(
24202451
biomol,
24212452
bond.ptnr1_auth_asym_id,
24222453
int(bond.ptnr1_auth_seq_id),
2423-
bond.ptnr1_label_atom_id,
2454+
ptnr1_label_atom_id,
24242455
bond_atom_indices[ptnr1_atom_id],
24252456
ptnr1_is_polymer,
24262457
)
24272458
except Exception as e:
2428-
logger.warning(
2429-
f"Could not find a matching token index for token1 {ptnr1_atom_id} due to: {e}. "
2430-
"Skipping installing the current bond associated with this token."
2431-
)
2459+
if verbose:
2460+
logger.warning(
2461+
f"Could not find a matching token index for token1 {ptnr1_atom_id} due to: {e}. "
2462+
"Skipping installing the current bond associated with this token."
2463+
)
24322464
continue
24332465
try:
24342466
col_idx = get_token_index_from_composite_atom_id(
24352467
biomol,
24362468
bond.ptnr2_auth_asym_id,
24372469
int(bond.ptnr2_auth_seq_id),
2438-
bond.ptnr2_label_atom_id,
2470+
ptnr2_label_atom_id,
24392471
bond_atom_indices[ptnr2_atom_id],
24402472
ptnr2_is_polymer,
24412473
)
24422474
except Exception as e:
2443-
logger.warning(
2444-
f"Could not find a matching token index for token2 {ptnr1_atom_id} due to: {e}. "
2445-
"Skipping installing the current bond associated with this token."
2446-
)
2475+
if verbose:
2476+
logger.warning(
2477+
f"Could not find a matching token index for token2 {ptnr1_atom_id} due to: {e}. "
2478+
"Skipping installing the current bond associated with this token."
2479+
)
24472480
continue
24482481
token_bonds[row_idx, col_idx] = True
24492482
token_bonds[col_idx, row_idx] = True

alphafold3_pytorch/utils/data_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,15 @@ def deep_merge_dicts(
171171
# Otherwise, set/overwrite the key in dict1 with dict2's value
172172
dict1[key] = value
173173
return dict1
174+
175+
176+
@typecheck
177+
def remove_last_digit_character(string: str) -> str:
178+
"""Remove the last digit character from a string.
179+
180+
:param string: The string to remove the last digit character from.
181+
:return: The string with the last digit character removed
182+
"""
183+
if len(string) > 1 and string[-1].isdigit():
184+
return string[:-1]
185+
return string

tests/test_af3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ def test_distogram_head():
514514
@pytest.mark.parametrize('stochastic_frame_average', (True, False))
515515
@pytest.mark.parametrize('missing_atoms', (True, False))
516516
@pytest.mark.parametrize('atom_transformer_intramolecular_attn', (True, False))
517-
@pytest.mark.parametrize('num_molecule_mods', (0, 5))
517+
@pytest.mark.parametrize('num_molecule_mods', (0, 4))
518518
@pytest.mark.parametrize('confidence_head_atom_resolution', (True, False))
519519
def test_alphafold3(
520520
window_atompair_inputs: bool,

0 commit comments

Comments
 (0)