|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from functools import wraps, partial |
4 | | -from dataclasses import dataclass, asdict, field |
5 | | -from typing import Type, Literal, Callable, List, Any, Tuple |
6 | | - |
| 3 | +import einx |
| 4 | +import json |
| 5 | +import os |
7 | 6 | import torch |
8 | | -from torch import tensor |
| 7 | + |
9 | 8 | import torch.nn.functional as F |
10 | | -import einx |
11 | 9 |
|
12 | | -from rdkit.Chem import AllChem as Chem |
| 10 | +from dataclasses import dataclass, asdict, field |
| 11 | +from functools import wraps, partial |
| 12 | +from loguru import logger |
| 13 | +from pdbeccdutils.core import ccd_reader |
| 14 | +from rdkit import Chem |
13 | 15 | from rdkit.Chem.rdchem import Atom, Mol |
| 16 | +from torch import tensor |
| 17 | +from typing import Type, Literal, Callable, List, Any, Tuple |
14 | 18 |
|
15 | 19 | from alphafold3_pytorch.attention import ( |
16 | 20 | pad_to_length |
17 | 21 | ) |
18 | 22 |
|
19 | | -from alphafold3_pytorch.tensor_typing import ( |
20 | | - typecheck, |
21 | | - beartype_isinstance, |
22 | | - Int, Bool, Float |
| 23 | +from alphafold3_pytorch.common.biomolecule import ( |
| 24 | + _from_mmcif_object, |
| 25 | + get_residue_constants, |
23 | 26 | ) |
24 | 27 |
|
| 28 | +from alphafold3_pytorch.data import mmcif_parsing |
| 29 | +from alphafold3_pytorch.data.data_pipeline import get_assembly |
| 30 | + |
25 | 31 | from alphafold3_pytorch.life import ( |
26 | 32 | HUMAN_AMINO_ACIDS, |
27 | 33 | DNA_NUCLEOTIDES, |
|
36 | 42 | reverse_complement_tensor |
37 | 43 | ) |
38 | 44 |
|
| 45 | +from alphafold3_pytorch.tensor_typing import ( |
| 46 | + typecheck, |
| 47 | + beartype_isinstance, |
| 48 | + Int, Bool, Float |
| 49 | +) |
| 50 | + |
# constants

IS_MOLECULE_TYPES = 4
ADDITIONAL_MOLECULE_FEATS = 5

CCD_COMPONENTS_FILEPATH = os.path.join("data", "ccd_data", "components.cif")
CCD_COMPONENTS_SMILES_FILEPATH = os.path.join("data", "ccd_data", "components_smiles.json")

# load all SMILES strings in the PDB Chemical Component Dictionary (CCD)
#
# resolution order:
#   1. the cached JSON file of SMILES strings, if it exists
#   2. otherwise the raw CCD components file, from which the JSON cache is built and saved
#   3. otherwise `CCD_COMPONENTS_SMILES` is left as `None`

CCD_COMPONENTS_SMILES = None

if os.path.exists(CCD_COMPONENTS_SMILES_FILEPATH):
    logger.info(f"Loading CCD component SMILES strings from {CCD_COMPONENTS_SMILES_FILEPATH}.")
    with open(CCD_COMPONENTS_SMILES_FILEPATH) as f:
        CCD_COMPONENTS_SMILES = json.load(f)
elif os.path.exists(CCD_COMPONENTS_FILEPATH):
    logger.info(
        f"Loading CCD components from {CCD_COMPONENTS_FILEPATH} to extract all available SMILES strings (~3 minutes, one-time only)."
    )
    CCD_COMPONENTS = ccd_reader.read_pdb_components_file(
        CCD_COMPONENTS_FILEPATH,
        sanitize=False,  # Reduce loading time
    )

    # build the full mapping *before* opening the cache file for writing:
    # opening with "w" creates/truncates the file immediately, so a failure during
    # SMILES generation would otherwise leave an empty cache file that the first
    # branch above would try (and fail) to parse on every subsequent import

    CCD_COMPONENTS_SMILES = {
        ccd_code: Chem.MolToSmiles(CCD_COMPONENTS[ccd_code].component.mol)
        for ccd_code in CCD_COMPONENTS
    }

    logger.info(
        f"Saving CCD component SMILES strings to {CCD_COMPONENTS_SMILES_FILEPATH} (one-time only)."
    )
    with open(CCD_COMPONENTS_SMILES_FILEPATH, "w") as f:
        json.dump(CCD_COMPONENTS_SMILES, f)
| 84 | + |
44 | 85 | # functions |
45 | 86 |
|
46 | 87 | def exists(v): |
@@ -740,9 +781,9 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]: |
740 | 781 | unrepeated_sym_ids = [ |
741 | 782 | *[*range(len(i.proteins))], |
742 | 783 | *[*range(len(i.ss_rna))], |
743 | | - *[i for rna in i.ds_rna for i in range(2)], |
| 784 | + *[i for _ in i.ds_rna for i in range(2)], |
744 | 785 | *[*range(len(i.ss_dna))], |
745 | | - *[i for dna in i.ds_dna for i in range(2)], |
| 786 | + *[i for _ in i.ds_dna for i in range(2)], |
746 | 787 | *([0] * len(mol_ligands)), |
747 | 788 | 0 |
748 | 789 | ] |
@@ -861,9 +902,112 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]: |
861 | 902 | class PDBInput: |
862 | 903 | filepath: str |
863 | 904 |
|
@typecheck
def extract_chain_sequences_from_chemical_components(
    chem_comps: List[mmcif_parsing.ChemComp],
) -> Tuple[List[str], List[str], List[str], List[Mol | str]]:
    """Group a flat list of chemical components into chain sequences and ligand SMILES strings.

    Each component is classified by a case-insensitive substring match on its `type`,
    checked in the order "peptide", "dna", "rna"; anything else is treated as a ligand.

    Consecutive components of the same polymer kind are merged into a single sequence of
    one-letter residue codes (unknown residue ids map to "X"). Note that this means two
    adjacent chains of the same kind cannot be told apart from `chem_comps` alone and end
    up in one sequence. Every ligand component yields its own entry, looked up by `id` in
    `CCD_COMPONENTS_SMILES`.

    Args:
        chem_comps: The chemical components, in structure order. Only the `type` and `id`
            attributes of each component are read.

    Returns:
        A tuple `(proteins, ss_dna, ss_rna, ligands)`, where the first three are lists of
        sequence strings and `ligands` is a list of SMILES strings.

    Raises:
        AssertionError: If the CCD SMILES table has not been loaded.
        KeyError: If a ligand's id is not present in the CCD SMILES table.
    """
    assert exists(CCD_COMPONENTS_SMILES), (
        f"The PDB Chemical Component Dictionary (CCD) components SMILES file {CCD_COMPONENTS_SMILES_FILEPATH} does not exist. "
        f"Please re-run this script after ensuring the preliminary CCD file {CCD_COMPONENTS_FILEPATH} has been downloaded according to this project's `README.md` file. "
        f"After doing so, the SMILES file {CCD_COMPONENTS_SMILES_FILEPATH} will be cached locally and used for subsequent runs."
    )

    # insertion order doubles as the classification priority (peptide > dna > rna)

    polymer_chains = {"peptide": [], "dna": [], "rna": []}
    ligands = []

    def polymer_kind(chem_type):
        # returns the first matching polymer kind, or `None` for a ligand
        lowered = chem_type.lower()
        for kind in polymer_chains:
            if kind in lowered:
                return kind
        return None

    # kind of the polymer run currently being extended (`None` right after a ligand)

    current_kind = None

    for details in chem_comps:
        kind = polymer_kind(details.type)

        # Ligand SMILES strings - each ligand is its own entry and ends any polymer run

        if kind is None:
            try:
                smiles = CCD_COMPONENTS_SMILES[details.id]
            except KeyError as err:
                raise KeyError(
                    f"Chemical component `{details.id}` was not found in the CCD components SMILES file {CCD_COMPONENTS_SMILES_FILEPATH}."
                ) from err

            ligands.append(smiles)
            current_kind = None
            continue

        # Protein / DNA / RNA residues - start a new chain whenever the kind changes

        if kind != current_kind:
            polymer_chains[kind].append([])
            current_kind = kind

        residue_constants = get_residue_constants(details.type)
        restype = residue_constants.restype_3to1.get(details.id, "X")
        polymer_chains[kind][-1].append(restype)

    # Efficiently build sequence strings

    proteins = ["".join(chain) for chain in polymer_chains["peptide"]]
    ss_dna = ["".join(chain) for chain in polymer_chains["dna"]]
    ss_rna = ["".join(chain) for chain in polymer_chains["rna"]]

    return proteins, ss_dna, ss_rna, ligands
| 969 | + |
@typecheck
def pdb_input_to_alphafold3_input(pdb_input: PDBInput) -> Alphafold3Input:
    """Convert a `PDBInput` (a path to an mmCIF file) into an `Alphafold3Input`.

    The file is parsed into a biomolecule; unless the file's base name contains
    "assembly", `get_assembly` is applied to the parsed biomolecule. Chain sequences and
    ligand SMILES strings are then extracted from the per-residue chemical components, and
    the masked atom positions are passed along as a float32 tensor.

    Raises:
        AssertionError: If `pdb_input.filepath` does not exist.
    """
    filepath = pdb_input.filepath
    file_id, _ = os.path.splitext(os.path.basename(filepath))
    assert os.path.exists(filepath), f"PDB input file `{filepath}` does not exist."

    mmcif_object = mmcif_parsing.parse_mmcif_object(
        filepath=filepath,
        file_id=file_id,
    )

    # files whose name already marks them as an assembly are used as parsed

    biomol = _from_mmcif_object(mmcif_object)
    if "assembly" not in file_id:
        biomol = get_assembly(biomol)

    # look up the chemical component details of every residue, in residue order

    comps_by_id = {comp.id: comp for comp in biomol.chem_comp_table}
    per_residue_comps = [comps_by_id[chemid] for chemid in biomol.chemid]

    proteins, ss_dna, ss_rna, ligands = extract_chain_sequences_from_chemical_components(
        per_residue_comps
    )

    # keep only the positions of atoms flagged in the atom mask

    present_atoms = biomol.atom_mask.astype(bool)
    atom_positions = biomol.atom_positions[present_atoms]
    atom_pos = torch.from_numpy(atom_positions.astype("float32"))

    alphafold_input = Alphafold3Input(
        proteins=proteins,
        ss_dna=ss_dna,
        ss_rna=ss_rna,
        ligands=ligands,
        atom_pos=atom_pos,
    )

    # TODO: Add support for AlphaFold 2-style amino/nucleic acid atom parametrization (i.e., 47 possible atom types per residue)

    # TODO: Reference bonds from `biomol` instead of instantiating them within `Alphafold3Input`

    # TODO: Ensure only polymer-ligand (e.g., protein/RNA/DNA-ligand) and ligand-ligand bonds
    # (and bonds less than 2.4 Å) are referenced in `Alphafold3Input` (AF3 Supplement - Table 5, `token_bonds`)

    return alphafold_input
867 | 1011 |
|
868 | 1012 | # the config used for keeping track of all the disparate inputs and their transforms down to AtomInput |
869 | 1013 | # this can be preprocessed or will be taken care of automatically within the Trainer during data collation |
|
0 commit comments