77from functools import partial , wraps
88from itertools import groupby
99from collections import defaultdict
10- from collections .abc import Iterable
10+ from collections .abc import Iterableassignment
1111from dataclasses import asdict , dataclass , field
1212from typing import Any , Callable , Dict , List , Literal , Set , Tuple , Type
1313
2828
2929from pdbeccdutils .core import ccd_reader
3030
31- from rdkit import Chem
31+ from rdkit import Chem , RDLogger
3232from rdkit .Chem import AllChem , rdDetermineBonds
3333from rdkit .Chem .rdchem import Atom , Mol
3434from rdkit .Geometry import Point3D
6161 get_pdb_input_residue_molecule_type ,
6262 is_atomized_residue ,
6363 is_polymer ,
64+ remove_last_digit_character ,
6465)
6566from alphafold3_pytorch .utils .model_utils import exclusive_cumsum
6667from alphafold3_pytorch .utils .utils import default , exists , first , identity
6768
69+ # silence RDKit's warnings
70+
71+ RDLogger .DisableLog ("rdApp.*" )
72+
6873# constants
6974
7075IS_MOLECULE_TYPES = 5
8489MOLECULE_METAL_ION_ID = MOLECULE_GAP_ID + 1
8590NUM_MOLECULE_IDS = len (HUMAN_AMINO_ACIDS ) + len (RNA_NUCLEOTIDES ) + len (DNA_NUCLEOTIDES ) + 2
8691
87- DEFAULT_NUM_MOLECULE_MODS = 5
92+ DEFAULT_NUM_MOLECULE_MODS = 4 # `mod_protein`, `mod_rna`, `mod_dna`, and `mod_unk`
8893ADDITIONAL_MOLECULE_FEATS = 5
8994
9095CCD_COMPONENTS_FILEPATH = os .path .join ('data' , 'ccd_data' , 'components.cif' )
@@ -669,7 +674,7 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
669674 atom_input = AtomInput (
670675 atom_inputs = atom_inputs_tensor ,
671676 atompair_inputs = atompair_inputs ,
672- molecule_atom_lens = tensor ( atom_lens , dtype = torch .long ),
677+ molecule_atom_lens = atom_lens .long ( ),
673678 molecule_ids = i .molecule_ids ,
674679 molecule_atom_indices = i .molecule_atom_indices ,
675680 distogram_atom_indices = i .distogram_atom_indices ,
@@ -1750,20 +1755,24 @@ def add_atom_positions_to_mol(
17501755
17511756
17521757def create_mol_from_atom_positions_and_types (
1758+ name : str ,
17531759 atom_positions : np .ndarray ,
17541760 element_types : List [str ],
17551761 missing_atom_indices : Set [int ],
17561762 num_bond_attempts : int = 2 ,
1763+ verbose : bool = False ,
17571764) -> Mol :
17581765 """Create an RDKit molecule from a NumPy array of atom positions and a list of their element
17591766 types.
17601767
1768+ :param name: The name of the molecule.
17611769 :param atom_positions: A NumPy array of shape (num_atoms, 3) containing the 3D coordinates of
17621770 each atom.
17631771 :param element_types: A list of element symbols for each atom in the molecule.
17641772 :param missing_atom_indices: A set of atom indices that are missing from the atom_positions
17651773 array.
17661774 :param num_bond_attempts: The number of attempts to determine the bonds in the molecule.
1775+ :param verbose: Whether to log warnings when bond determination fails.
17671776 :return: An RDKit molecule with the specified atom positions and element types.
17681777 """
17691778 if len (atom_positions ) != len (element_types ):
@@ -1772,6 +1781,7 @@ def create_mol_from_atom_positions_and_types(
17721781 # populate an empty editable molecule
17731782
17741783 mol = Chem .RWMol ()
1784+ mol .SetProp ("_Name" , name )
17751785
17761786 for element_type in element_types :
17771787 atom = Chem .Atom (element_type )
@@ -1789,18 +1799,23 @@ def create_mol_from_atom_positions_and_types(
17891799 # finalize molecule by inferring bonds
17901800
17911801 determined_bonds = False
1792- for _ in range (num_bond_attempts ):
1802+ for i in range (num_bond_attempts ):
17931803 try :
1794- rdDetermineBonds .DetermineBonds (mol , charge = Chem .GetFormalCharge (mol ))
1804+ charge = Chem .GetFormalCharge (mol )
1805+ rdDetermineBonds .DetermineBonds (mol , charge = charge )
17951806 determined_bonds = True
17961807 except Exception as e :
1797- logger .warning (
1798- f"Failed to determine bonds in the input molecule due to: { e } . "
1799- "Retrying once more."
1800- )
1808+ if verbose :
1809+ logger .warning (
1810+ f"Failed to determine bonds for the input molecule { name } due to: { e } . "
1811+ f"{ 'Retrying once more.' if i < num_bond_attempts - 1 else 'Terminating bond assignment.' } "
1812+ )
18011813 continue
18021814 if not determined_bonds :
1803- raise ValueError ("Failed to determine bonds in the input molecule." )
1815+ if verbose :
1816+ logger .warning (
1817+ "Failed to determine bonds in the input molecule. Skipping bond assignment."
1818+ )
18041819
18051820 mol = Chem .RemoveHs (mol )
18061821 Chem .SanitizeMol (mol )
@@ -1818,6 +1833,7 @@ def extract_template_molecules_from_biomolecule_chains(
18181833 chain_seqs : List [str ],
18191834 chain_chem_types : List [PDB_INPUT_RESIDUE_MOLECULE_TYPE ],
18201835 mol_keyname : str = "rdchem_mol" ,
1836+ verbose : bool = False ,
18211837) -> Tuple [List [Mol ], List [PDB_INPUT_RESIDUE_MOLECULE_TYPE ]]:
18221838 """Extract RDKit template molecules and their types for the residues of each `Biomolecule`
18231839 chain.
@@ -1902,23 +1918,26 @@ def extract_template_molecules_from_biomolecule_chains(
19021918 res_atom_type_indices = np .where (res_atom_positions .all (axis = - 1 ))[1 ]
19031919 res_atom_elements = [
19041920 # NOTE: here, we treat the first character of each atom type as its element symbol
1905- res_constants .element_types [idx ]
1921+ res_constants .element_types [idx ]. replace ( "ATM" , "*" )
19061922 for idx in res_atom_type_indices
19071923 ]
19081924 mol = create_mol_from_atom_positions_and_types (
19091925 # NOTE: for now, we construct molecules without referencing canonical
19101926 # SMILES strings, which means there are no missing molecule atoms by design
1911- res_atom_positions [res_atom_mask ],
1912- res_atom_elements ,
1927+ name = seq ,
1928+ atom_positions = res_atom_positions [res_atom_mask ],
1929+ element_types = res_atom_elements ,
19131930 missing_atom_indices = set (),
1931+ verbose = verbose ,
19141932 )
19151933 try :
19161934 mol = AllChem .AssignBondOrdersFromTemplate (template_mol , mol )
19171935 except Exception as e :
1918- logger .warning (
1919- f"Failed to assign bond orders from the template ligand molecule for residue { res } due to: { e } . "
1920- "Skipping bond order assignment."
1921- )
1936+ if verbose :
1937+ logger .warning (
1938+ f"Failed to assign bond orders from the template atomized molecule for residue { seq } due to: { e } . "
1939+ "Skipping bond order assignment."
1940+ )
19221941 res_index += mol .GetNumAtoms ()
19231942
19241943 # (Unmodified) polymer residues
@@ -1977,6 +1996,7 @@ def extract_template_molecules_from_biomolecule_chains(
19771996 res_atom_positions .reshape (- 1 , 3 ),
19781997 missing_atom_indices ,
19791998 )
1999+ mol .SetProp ("_Name" , res )
19802000 res_index += 1
19812001
19822002 mol_seq .append (mol )
@@ -2079,7 +2099,9 @@ def find_mismatched_symmetry(
20792099
20802100@typecheck
20812101def pdb_input_to_molecule_input (
2082- pdb_input : PDBInput , biomol : Biomolecule | None = None
2102+ pdb_input : PDBInput ,
2103+ biomol : Biomolecule | None = None ,
2104+ verbose : bool = False ,
20832105) -> MoleculeInput :
20842106 """Convert a PDBInput to a MoleculeInput."""
20852107 i = pdb_input
@@ -2173,16 +2195,19 @@ def pdb_input_to_molecule_input(
21732195 biomol ,
21742196 chain_seqs ,
21752197 chain_chem_types ,
2198+ verbose = verbose ,
21762199 )
21772200
21782201 # collect pooling lengths and atom-wise molecule types for each molecule,
2179- # along with a token-wise boolean tensor indicating whether each molecule is modified
2202+ # along with a token-wise one-hot tensor indicating whether each molecule is modified
2203+ # and, if so, which type of modification it has (e.g., peptide vs. RNA modification)
21802204 molecule_idx = 0
21812205 token_pool_lens = []
21822206 molecule_atom_types = []
21832207 is_molecule_mod = []
21842208 for mol , mol_type in zip (molecules , molecule_types ):
21852209 num_atoms = mol .GetNumAtoms ()
2210+ is_mol_mod_type = [False for _ in range (DEFAULT_NUM_MOLECULE_MODS )]
21862211 if is_atomized_residue (mol_type ):
21872212 # NOTE: in the paper, they treat each atom of the ligand and modified polymer residues as a token
21882213 token_pool_lens .extend ([1 ] * num_atoms )
@@ -2201,26 +2226,30 @@ def pdb_input_to_molecule_input(
22012226 is_mol_type_index = IS_LIGAND_INDEX
22022227 elif mol_type == "mod_protein" :
22032228 is_mol_type_index = IS_PROTEIN_INDEX
2229+ is_mol_mod_type_index = 0
22042230 elif mol_type == "mod_rna" :
22052231 is_mol_type_index = IS_RNA_INDEX
2232+ is_mol_mod_type_index = 1
22062233 elif mol_type == "mod_dna" :
22072234 is_mol_type_index = IS_DNA_INDEX
2235+ is_mol_mod_type_index = 2
22082236 else :
22092237 raise ValueError (f"Unrecognized molecule type: { mol_type } " )
22102238
22112239 is_molecule_types [molecule_type_row_idx , is_mol_type_index ] = True
22122240
2213- is_molecule_mod .extend ([True if "mod" in mol_type else False ] * num_atoms )
2241+ if "mod" in mol_type :
2242+ is_mol_mod_type [is_mol_mod_type_index ] = True
2243+ is_molecule_mod .extend ([is_mol_mod_type ] * num_atoms )
22142244
22152245 molecule_idx += num_atoms
22162246 else :
22172247 token_pool_lens .append (num_atoms )
22182248 molecule_atom_types .append (mol_type )
2219- is_molecule_mod .append (False )
2249+ is_molecule_mod .append (is_mol_mod_type )
22202250 molecule_idx += 1
22212251
22222252 # collect token center, distogram, and source-target atom indices for each token
2223- molecule_idx = 0
22242253 molecule_atom_indices = []
22252254 distogram_atom_indices = []
22262255 src_tgt_atom_indices = []
@@ -2415,35 +2444,39 @@ def pdb_input_to_molecule_input(
24152444 ptnr2_atom_id = (
24162445 f"{ bond .ptnr2_auth_asym_id } :{ bond .ptnr2_auth_seq_id } :{ bond .ptnr2_label_atom_id } "
24172446 )
2447+ ptnr1_label_atom_id = remove_last_digit_character (bond .ptnr1_label_atom_id )
2448+ ptnr2_label_atom_id = remove_last_digit_character (bond .ptnr2_label_atom_id )
24182449 try :
24192450 row_idx = get_token_index_from_composite_atom_id (
24202451 biomol ,
24212452 bond .ptnr1_auth_asym_id ,
24222453 int (bond .ptnr1_auth_seq_id ),
2423- bond . ptnr1_label_atom_id ,
2454+ ptnr1_label_atom_id ,
24242455 bond_atom_indices [ptnr1_atom_id ],
24252456 ptnr1_is_polymer ,
24262457 )
24272458 except Exception as e :
2428- logger .warning (
2429- f"Could not find a matching token index for token1 { ptnr1_atom_id } due to: { e } . "
2430- "Skipping installing the current bond associated with this token."
2431- )
2459+ if verbose :
2460+ logger .warning (
2461+ f"Could not find a matching token index for token1 { ptnr1_atom_id } due to: { e } . "
2462+ "Skipping installing the current bond associated with this token."
2463+ )
24322464 continue
24332465 try :
24342466 col_idx = get_token_index_from_composite_atom_id (
24352467 biomol ,
24362468 bond .ptnr2_auth_asym_id ,
24372469 int (bond .ptnr2_auth_seq_id ),
2438- bond . ptnr2_label_atom_id ,
2470+ ptnr2_label_atom_id ,
24392471 bond_atom_indices [ptnr2_atom_id ],
24402472 ptnr2_is_polymer ,
24412473 )
24422474 except Exception as e :
2443- logger .warning (
2444- f"Could not find a matching token index for token2 { ptnr1_atom_id } due to: { e } . "
2445- "Skipping installing the current bond associated with this token."
2446- )
2475+ if verbose :
2476+ logger .warning (
2477+ f"Could not find a matching token index for token2 { ptnr1_atom_id } due to: { e } . "
2478+ "Skipping installing the current bond associated with this token."
2479+ )
24472480 continue
24482481 token_bonds [row_idx , col_idx ] = True
24492482 token_bonds [col_idx , row_idx ] = True
0 commit comments