Skip to content

Commit bf679bf

Browse files
authored
Take a step towards loading, training, and sampling with mmCIF files (#74)
* Update trainer.py * Update data_pipeline.py * Create 7a4d-assembly1.cif * Update mmcif_writing.py * Update test_input.py * Update inputs.py * Fix test-time error in `inputs.py` * Update __init__.py * Update __init__.py
1 parent 3add55a commit bf679bf

File tree

7 files changed

+9682
-19
lines changed

7 files changed

+9682
-19
lines changed

alphafold3_pytorch/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
BatchedAtomInput,
4040
MoleculeInput,
4141
Alphafold3Input,
42+
PDBInput,
4243
maybe_transform_to_atom_input,
4344
maybe_transform_to_atom_inputs
4445
)
@@ -47,7 +48,8 @@
4748
Trainer,
4849
DataLoader,
4950
collate_inputs_to_batched_atom_input,
50-
alphafold3_inputs_to_batched_atom_input
51+
alphafold3_inputs_to_batched_atom_input,
52+
pdb_inputs_to_batched_atom_input,
5153
)
5254

5355
from alphafold3_pytorch.configs import (
@@ -90,10 +92,12 @@
9092
Alphafold3WithHubMixin,
9193
Alphafold3Config,
9294
AtomInput,
95+
PDBInput,
9396
Trainer,
9497
TrainerConfig,
9598
ConductorConfig,
9699
create_alphafold3_from_yaml,
97100
create_trainer_from_yaml,
98-
create_trainer_from_conductor_yaml
101+
create_trainer_from_conductor_yaml,
102+
pdb_inputs_to_batched_atom_input,
99103
]

alphafold3_pytorch/data/data_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def make_mmcif_features(
140140

141141

142142
if __name__ == "__main__":
143-
filepath = "data/pdb_data/mmcifs/ak/7akd-assembly1.cif"
143+
filepath = os.path.join("data", "test", "7a4d-assembly1.cif")
144144
file_id = os.path.splitext(os.path.basename(filepath))[0]
145145

146146
mmcif_object = mmcif_parsing.parse_mmcif_object(

alphafold3_pytorch/data/mmcif_writing.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
"""An mmCIF file format writer."""
22

3+
import numpy as np
4+
5+
from typing import Optional
6+
37
from alphafold3_pytorch.common.biomolecule import (
48
_from_mmcif_object,
59
to_mmcif,
610
)
11+
from alphafold3_pytorch.data.data_pipeline import get_assembly
712
from alphafold3_pytorch.data.mmcif_parsing import MmcifObject
13+
from alphafold3_pytorch.utils.utils import exists
814

915

1016
def write_mmcif(
@@ -13,9 +19,21 @@ def write_mmcif(
1319
gapless_poly_seq: bool = True,
1420
insert_orig_atom_names: bool = True,
1521
insert_alphafold_mmcif_metadata: bool = True,
22+
sampled_atom_positions: Optional[np.ndarray] = None,
1623
):
1724
"""Write a BioPython `Structure` object to an mmCIF file using an intermediate `Biomolecule` object."""
18-
biomol = _from_mmcif_object(mmcif_object)
25+
biomol = (
26+
_from_mmcif_object(mmcif_object)
27+
if "assembly" in mmcif_object.file_id
28+
else get_assembly(_from_mmcif_object(mmcif_object))
29+
)
30+
if exists(sampled_atom_positions):
31+
atom_mask = biomol.atom_mask.astype(bool)
32+
assert biomol.atom_positions[atom_mask].shape == sampled_atom_positions.shape, (
33+
f"Expected sampled atom positions to have masked shape {biomol.atom_positions[atom_mask].shape}, "
34+
f"but got {sampled_atom_positions.shape}."
35+
)
36+
biomol.atom_positions[atom_mask] = sampled_atom_positions
1937
unique_res_atom_names = biomol.unique_res_atom_names if insert_orig_atom_names else None
2038
mmcif_string = to_mmcif(
2139
biomol,

alphafold3_pytorch/inputs.py

Lines changed: 158 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,33 @@
11
from __future__ import annotations
22

3-
from functools import wraps, partial
4-
from dataclasses import dataclass, asdict, field
5-
from typing import Type, Literal, Callable, List, Any, Tuple
6-
3+
import einx
4+
import json
5+
import os
76
import torch
8-
from torch import tensor
7+
98
import torch.nn.functional as F
10-
import einx
119

12-
from rdkit.Chem import AllChem as Chem
10+
from dataclasses import dataclass, asdict, field
11+
from functools import wraps, partial
12+
from loguru import logger
13+
from pdbeccdutils.core import ccd_reader
14+
from rdkit import Chem
1315
from rdkit.Chem.rdchem import Atom, Mol
16+
from torch import tensor
17+
from typing import Type, Literal, Callable, List, Any, Tuple
1418

1519
from alphafold3_pytorch.attention import (
1620
pad_to_length
1721
)
1822

19-
from alphafold3_pytorch.tensor_typing import (
20-
typecheck,
21-
beartype_isinstance,
22-
Int, Bool, Float
23+
from alphafold3_pytorch.common.biomolecule import (
24+
_from_mmcif_object,
25+
get_residue_constants,
2326
)
2427

28+
from alphafold3_pytorch.data import mmcif_parsing
29+
from alphafold3_pytorch.data.data_pipeline import get_assembly
30+
2531
from alphafold3_pytorch.life import (
2632
HUMAN_AMINO_ACIDS,
2733
DNA_NUCLEOTIDES,
@@ -36,11 +42,46 @@
3642
reverse_complement_tensor
3743
)
3844

45+
from alphafold3_pytorch.tensor_typing import (
46+
typecheck,
47+
beartype_isinstance,
48+
Int, Bool, Float
49+
)
50+
3951
# constants
4052

4153
IS_MOLECULE_TYPES = 4
4254
ADDITIONAL_MOLECULE_FEATS = 5
4355

56+
CCD_COMPONENTS_FILEPATH = os.path.join("data", "ccd_data", "components.cif")
57+
CCD_COMPONENTS_SMILES_FILEPATH = os.path.join("data", "ccd_data", "components_smiles.json")
58+
59+
# load all SMILES strings in the PDB Chemical Component Dictionary (CCD)
60+
61+
CCD_COMPONENTS_SMILES = None
62+
63+
if os.path.exists(CCD_COMPONENTS_SMILES_FILEPATH):
64+
logger.info(f"Loading CCD component SMILES strings from {CCD_COMPONENTS_SMILES_FILEPATH}.")
65+
with open(CCD_COMPONENTS_SMILES_FILEPATH) as f:
66+
CCD_COMPONENTS_SMILES = json.load(f)
67+
elif os.path.exists(CCD_COMPONENTS_FILEPATH):
68+
logger.info(
69+
f"Loading CCD components from {CCD_COMPONENTS_FILEPATH} to extract all available SMILES strings (~3 minutes, one-time only)."
70+
)
71+
CCD_COMPONENTS = ccd_reader.read_pdb_components_file(
72+
CCD_COMPONENTS_FILEPATH,
73+
sanitize=False, # Reduce loading time
74+
)
75+
logger.info(
76+
f"Saving CCD component SMILES strings to {CCD_COMPONENTS_SMILES_FILEPATH} (one-time only)."
77+
)
78+
with open(CCD_COMPONENTS_SMILES_FILEPATH, "w") as f:
79+
CCD_COMPONENTS_SMILES = {
80+
ccd_code: Chem.MolToSmiles(CCD_COMPONENTS[ccd_code].component.mol)
81+
for ccd_code in CCD_COMPONENTS
82+
}
83+
json.dump(CCD_COMPONENTS_SMILES, f)
84+
4485
# functions
4586

4687
def exists(v):
@@ -740,9 +781,9 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
740781
unrepeated_sym_ids = [
741782
*[*range(len(i.proteins))],
742783
*[*range(len(i.ss_rna))],
743-
*[i for rna in i.ds_rna for i in range(2)],
784+
*[i for _ in i.ds_rna for i in range(2)],
744785
*[*range(len(i.ss_dna))],
745-
*[i for dna in i.ds_dna for i in range(2)],
786+
*[i for _ in i.ds_dna for i in range(2)],
746787
*([0] * len(mol_ligands)),
747788
0
748789
]
@@ -861,9 +902,112 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
861902
class PDBInput:
862903
filepath: str
863904

905+
@typecheck
906+
def extract_chain_sequences_from_chemical_components(
907+
chem_comps: List[mmcif_parsing.ChemComp],
908+
) -> Tuple[List[str], List[str], List[str], List[Mol | str]]:
909+
assert exists(CCD_COMPONENTS_SMILES), (
910+
f"The PDB Chemical Component Dictionary (CCD) components SMILES file {CCD_COMPONENTS_SMILES_FILEPATH} does not exist. "
911+
f"Please re-run this script after ensuring the preliminary CCD file {CCD_COMPONENTS_FILEPATH} has been downloaded according to this project's `README.md` file."
912+
f"After doing so, the SMILES file {CCD_COMPONENTS_SMILES_FILEPATH} will be cached locally and used for subsequent runs."
913+
)
914+
915+
current_chain_seq = []
916+
proteins, ss_dna, ss_rna, ligands = [], [], [], []
917+
918+
for idx, details in enumerate(chem_comps):
919+
residue_constants = get_residue_constants(details.type)
920+
restype = residue_constants.restype_3to1.get(details.id, "X")
921+
922+
# Protein residues
923+
924+
if "peptide" in details.type.lower():
925+
if not current_chain_seq:
926+
proteins.append(current_chain_seq)
927+
current_chain_seq.append(restype)
928+
# Reset current_chain_seq if the next residue is not a protein residue
929+
if idx + 1 < len(chem_comps) and "peptide" not in chem_comps[idx + 1].type.lower():
930+
current_chain_seq = []
931+
932+
# DNA residues
933+
934+
elif "dna" in details.type.lower():
935+
if not current_chain_seq:
936+
ss_dna.append(current_chain_seq)
937+
current_chain_seq.append(restype)
938+
# Reset current_chain_seq if the next residue is not a DNA residue
939+
if idx + 1 < len(chem_comps) and "dna" not in chem_comps[idx + 1].type.lower():
940+
current_chain_seq = []
941+
942+
# RNA residues
943+
944+
elif "rna" in details.type.lower():
945+
if not current_chain_seq:
946+
ss_rna.append(current_chain_seq)
947+
current_chain_seq.append(restype)
948+
# Reset current_chain_seq if the next residue is not a RNA residue
949+
if idx + 1 < len(chem_comps) and "rna" not in chem_comps[idx + 1].type.lower():
950+
current_chain_seq = []
951+
952+
# Ligand SMILES strings
953+
954+
else:
955+
if not current_chain_seq:
956+
ligands.append(current_chain_seq)
957+
current_chain_seq.append(CCD_COMPONENTS_SMILES[details.id])
958+
# Reset current_chain_seq after adding each ligand's SMILES string
959+
current_chain_seq = []
960+
961+
# Efficiently build sequence strings
962+
963+
proteins = ["".join(protein) for protein in proteins]
964+
ss_dna = ["".join(dna) for dna in ss_dna]
965+
ss_rna = ["".join(rna) for rna in ss_rna]
966+
ligands = ["".join(ligand) for ligand in ligands]
967+
968+
return proteins, ss_dna, ss_rna, ligands
969+
864970
@typecheck
865971
def pdb_input_to_alphafold3_input(pdb_input: PDBInput) -> Alphafold3Input:
866-
raise NotImplementedError
972+
filepath = pdb_input.filepath
973+
file_id = os.path.splitext(os.path.basename(filepath))[0]
974+
assert os.path.exists(filepath), f"PDB input file `{filepath}` does not exist."
975+
976+
mmcif_object = mmcif_parsing.parse_mmcif_object(
977+
filepath=filepath,
978+
file_id=file_id,
979+
)
980+
981+
biomol = (
982+
_from_mmcif_object(mmcif_object)
983+
if "assembly" in file_id
984+
else get_assembly(_from_mmcif_object(mmcif_object))
985+
)
986+
987+
chem_comp_table = {comp.id: comp for comp in biomol.chem_comp_table}
988+
chem_comp_details = [chem_comp_table[chemid] for chemid in biomol.chemid]
989+
990+
proteins, ss_dna, ss_rna, ligands = extract_chain_sequences_from_chemical_components(
991+
chem_comp_details
992+
)
993+
994+
atom_positions = biomol.atom_positions[biomol.atom_mask.astype(bool)]
995+
alphafold_input = Alphafold3Input(
996+
proteins=proteins,
997+
ss_dna=ss_dna,
998+
ss_rna=ss_rna,
999+
ligands=ligands,
1000+
atom_pos=torch.from_numpy(atom_positions.astype("float32")),
1001+
)
1002+
1003+
# TODO: Add support for AlphaFold 2-style amino/nucleic acid atom parametrization (i.e., 47 possible atom types per residue)
1004+
1005+
# TODO: Reference bonds from `biomol` instead of instantiating them within `Alphafold3Input`
1006+
1007+
# TODO: Ensure only polymer-ligand (e.g., protein/RNA/DNA-ligand) and ligand-ligand bonds
1008+
# (and bonds less than 2.4 Å) are referenced in `Alphafold3Input` (AF3 Supplement - Table 5, `token_bonds`)
1009+
1010+
return alphafold_input
8671011

8681012
# the config used for keeping track of all the disparate inputs and their transforms down to AtomInput
8691013
# this can be preprocessed or will be taken care of automatically within the Trainer during data collation

alphafold3_pytorch/trainer.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
AtomInput,
2424
BatchedAtomInput,
2525
Alphafold3Input,
26+
PDBInput,
2627
maybe_transform_to_atom_inputs,
2728
alphafold3_input_to_molecule_input
2829
)
@@ -200,6 +201,18 @@ def alphafold3_inputs_to_batched_atom_input(
200201
atom_inputs = maybe_transform_to_atom_inputs(inp)
201202
return collate_inputs_to_batched_atom_input(atom_inputs, **collate_kwargs)
202203

204+
@typecheck
205+
def pdb_inputs_to_batched_atom_input(
206+
inp: PDBInput | List[PDBInput],
207+
**collate_kwargs
208+
) -> BatchedAtomInput:
209+
210+
if isinstance(inp, PDBInput):
211+
inp = [inp]
212+
213+
atom_inputs = maybe_transform_to_atom_inputs(inp)
214+
return collate_inputs_to_batched_atom_input(atom_inputs, **collate_kwargs)
215+
203216
@typecheck
204217
def DataLoader(
205218
*args,

0 commit comments

Comments
 (0)