Skip to content

Commit c1fa9aa

Browse files
authored
Refactor mmCIF parsing helper functions to add a new (dummy) data pipeline (#62)
* Update data_pipeline.py
* Update mmcif_parsing.py
* Create mmcif_writing.py
* Update cluster_pdb_mmcifs.py
* Update filter_pdb_mmcifs.py
* Fix bug for protein clustering ratio in `cluster_pdb_mmcifs.py`
1 parent e6c4e96 commit c1fa9aa

File tree

5 files changed

+153
-141
lines changed

5 files changed

+153
-141
lines changed

alphafold3_pytorch/data/data_pipeline.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66

7-
from alphafold3_pytorch.common import amino_acid_constants
7+
from alphafold3_pytorch.common.biomolecule import _from_mmcif_object
88
from alphafold3_pytorch.data import mmcif_parsing
99

1010
FeatureDict = MutableMapping[str, np.ndarray]
@@ -13,23 +13,17 @@
1313
def make_sequence_features(sequence: str, description: str, num_res: int) -> FeatureDict:
1414
"""Construct a feature dict of sequence features."""
1515
features = {}
16-
features["restype"] = amino_acid_constants.sequence_to_onehot(
17-
sequence=sequence,
18-
mapping=amino_acid_constants.restype_order_with_x,
19-
map_unknown_to_x=True,
20-
)
2116
features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32)
2217
features["domain_name"] = np.array([description.encode("utf-8")], dtype=object)
23-
features["residue_index"] = np.array(range(num_res), dtype=np.int32)
2418
features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32)
2519
features["sequence"] = np.array([sequence.encode("utf-8")], dtype=object)
2620
return features
2721

2822

29-
def make_mmcif_features(mmcif_object: mmcif_parsing.MmcifObject, chain_id: str) -> FeatureDict:
23+
def make_mmcif_features(mmcif_object: mmcif_parsing.MmcifObject) -> FeatureDict:
3024
"""Make features from an mmCIF object."""
31-
input_sequence = mmcif_object.chain_to_seqres[chain_id]
32-
description = "_".join([mmcif_object.file_id, chain_id])
25+
input_sequence = "".join(mmcif_object.chain_to_seqres[chain_id] for chain_id in mmcif_object.chain_to_seqres)
26+
description = mmcif_object.file_id
3327
num_res = len(input_sequence)
3428

3529
mmcif_feats = {}
@@ -42,15 +36,19 @@ def make_mmcif_features(mmcif_object: mmcif_parsing.MmcifObject, chain_id: str)
4236
)
4337
)
4438

45-
all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords(
46-
mmcif_object=mmcif_object, chain_id=chain_id
47-
)
39+
biomol = _from_mmcif_object(mmcif_object)
4840

4941
# TODO: Expand the first bioassembly/model sequence and structure, to obtain a biologically relevant complex (AF3 Supplement, Section 2.1).
5042
# Reference: https://github.com/biotite-dev/biotite/blob/1045f43f80c77a0dc00865e924442385ce8f83ab/src/biotite/structure/io/pdbx/convert.py#L1441
5143

52-
mmcif_feats["all_atom_positions"] = all_atom_positions
53-
mmcif_feats["all_atom_mask"] = all_atom_mask
44+
mmcif_feats["all_atom_positions"] = biomol.atom_positions
45+
mmcif_feats["all_atom_mask"] = biomol.atom_mask
46+
mmcif_feats["b_factors"] = biomol.b_factors
47+
mmcif_feats["chain_index"] = biomol.chain_index
48+
mmcif_feats["chemid"] = biomol.chemid
49+
mmcif_feats["chemtype"] = biomol.chemtype
50+
mmcif_feats["residue_index"] = biomol.residue_index
51+
mmcif_feats["restype"] = biomol.restype
5452

5553
mmcif_feats["resolution"] = np.array([mmcif_object.header["resolution"]], dtype=np.float32)
5654

@@ -61,3 +59,13 @@ def make_mmcif_features(mmcif_object: mmcif_parsing.MmcifObject, chain_id: str)
6159
mmcif_feats["is_distillation"] = np.array(0.0, dtype=np.float32)
6260

6361
return mmcif_feats
62+
63+
64+
if __name__ == "__main__":
    # Manual smoke run: featurize one example mmCIF file that includes
    # protein, nucleic acid, and ligand residues.
    mmcif_object = mmcif_parsing.parse_mmcif_object(
        filepath="data/pdb_data/mmcifs/16/316d.cif", file_id="316d"
    )
    mmcif_feats = make_mmcif_features(mmcif_object)

alphafold3_pytorch/data/mmcif_parsing.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import io
66
import logging
77
from collections import defaultdict
8+
from operator import itemgetter
89
from typing import Any, Mapping, Optional, Sequence, Set, Tuple
910

1011
from Bio import PDB
@@ -645,3 +646,71 @@ def _get_complex_chains(
645646
def _is_set(data: str) -> bool:
646647
"""Returns False if data is a special mmCIF character indicating 'unset'."""
647648
return data not in (".", "?")
649+
650+
651+
def parse_mmcif_object(
    filepath: str, file_id: str, auth_chains: bool = True, auth_residues: bool = True
) -> MmcifObject:
    """Parse an mmCIF file into an `MmcifObject` containing a BioPython `Structure` object as well as associated metadata.

    :param filepath: Path of the mmCIF file to read.
    :param file_id: Identifier forwarded to `parse` as `file_id`.
    :param auth_chains: Forwarded unchanged to `parse`.
    :param auth_residues: Forwarded unchanged to `parse`.
    :return: The `mmcif_object` of the parsing result.
    :raises Exception: The first error recorded by `parse`, re-raised as-is when it is
        an exception instance.
    :raises ValueError: If parsing failed but no raisable exception was recorded
        (no errors at all, or an error that is not an exception instance).
    """
    with open(filepath, "r") as f:
        mmcif_string = f.read()

    parsing_result = parse(
        file_id=file_id,
        mmcif_string=mmcif_string,
        auth_chains=auth_chains,
        auth_residues=auth_residues,
    )

    if parsing_result.mmcif_object is not None:
        return parsing_result.mmcif_object

    # Crash if an error is encountered. Any parsing errors should have
    # been dealt with beforehand (e.g., at the alignment stage).
    first_error = next(iter(parsing_result.errors.values()), None)
    if isinstance(first_error, BaseException):
        raise first_error

    # Only exception instances can be raised; anything else (or an empty error
    # mapping) would otherwise surface as an unrelated `TypeError`/`IndexError`.
    details = "no error details were recorded" if first_error is None else str(first_error)
    raise ValueError(f"Failed to parse mmCIF file {filepath!r} (file_id={file_id!r}): {details}")
671+
672+
673+
def filter_mmcif(mmcif_object: MmcifObject) -> MmcifObject:
    """Filter an `MmcifObject` in place using its collected removal sets.

    Atoms, residues, and chains whose full IDs appear in
    `mmcif_object.atoms_to_remove`, `mmcif_object.residues_to_remove`, or
    `mmcif_object.chains_to_remove` are detached from `mmcif_object.structure`.
    A residue is also dropped when every one of its atoms is flagged, and a
    chain when every one of its residues is dropped. The per-chain entries of
    `mmcif_object.chem_comp_details` are removed alongside their residues, and
    the three removal sets are emptied before returning.

    :param mmcif_object: The parsed mmCIF object to filter (mutated in place).
    :return: The same `mmcif_object`, after filtering.
    """
    structure_model = mmcif_object.structure
    doomed_chains = set()

    for chain in structure_model:
        assert len(chain) == len(mmcif_object.chem_comp_details[chain.id]), (
            f"Number of residues in chain {chain.id} does not match "
            f"number of chemical component details for this chain: {len(chain)} vs. "
            f"{len(mmcif_object.chem_comp_details[chain.id])}."
        )

        # Maps a residue's position within the chain to the residue slated for removal.
        doomed_residues = {}
        for res_index, residue in enumerate(chain):
            doomed_atoms = {
                atom for atom in residue if atom.get_full_id() in mmcif_object.atoms_to_remove
            }
            # Compare sizes before detaching anything, while the residue is still intact.
            if len(doomed_atoms) == len(residue):
                doomed_residues[res_index] = residue
            for atom in doomed_atoms:
                residue.detach_child(atom.id)
            if residue.get_full_id() in mmcif_object.residues_to_remove:
                doomed_residues[res_index] = residue

        # Decide on whole-chain removal before any residue is detached from the chain.
        if len(doomed_residues) == len(chain):
            doomed_chains.add(chain)
        # Highest index first, so a deletion never shifts the positions still to be
        # processed (assuming the per-chain details container is list-like).
        for res_index in sorted(doomed_residues, reverse=True):
            del mmcif_object.chem_comp_details[chain.id][res_index]
            chain.detach_child(doomed_residues[res_index].id)
        if chain.get_full_id() in mmcif_object.chains_to_remove:
            doomed_chains.add(chain)

    # Chains are detached only after the iteration over the model has finished.
    for chain in doomed_chains:
        structure_model.detach_child(chain.id)
        mmcif_object.chem_comp_details.pop(chain.id)

    for removal_set in (
        mmcif_object.atoms_to_remove,
        mmcif_object.residues_to_remove,
        mmcif_object.chains_to_remove,
    ):
        removal_set.clear()

    return mmcif_object

alphafold3_pytorch/data/mmcif_writing.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""An mmCIF file format writer."""
2+
3+
from typing import List
4+
5+
from alphafold3_pytorch.common.biomolecule import (
6+
_from_mmcif_object,
7+
get_residue_constants,
8+
to_mmcif,
9+
)
10+
from alphafold3_pytorch.data.mmcif_parsing import MmcifObject
11+
from alphafold3_pytorch.utils.data_utils import is_polymer
12+
13+
14+
def get_unique_res_atom_names(mmcif_object: MmcifObject) -> List[List[List[str]]]:
    """Collect per-residue atom-name lists for every residue of every chain.

    Residues are visited chain by chain, paired positionally with the entries of
    `mmcif_object.chem_comp_details[chain.id]`. Each residue contributes one entry:

    - polymer residue (per `is_polymer`): a single-element list holding the
      `atom_types` of the residue constants for its chemical type;
    - non-polymer residue: one inner list per atom, containing that atom's name
      repeated `atom_type_num` times.

    :param mmcif_object: The parsed mmCIF object whose structure is traversed.
    :return: One nested list of atom names per residue, in traversal order.
    """
    names_per_residue = []
    for chain in mmcif_object.structure:
        for residue, chem_comp in zip(chain, mmcif_object.chem_comp_details[chain.id]):
            chem_type = chem_comp.type
            polymer = is_polymer(chem_type)
            constants = get_residue_constants(res_chem_type=chem_type)
            if not polymer:
                # Each atom acts as its own "pseudoresidue": its name fills a full row.
                residue_names = [[atom.name] * constants.atom_type_num for atom in residue]
            else:
                # Polymer residues reuse the canonical atom types directly.
                residue_names = [constants.atom_types]
            names_per_residue.append(residue_names)
    return names_per_residue
32+
33+
34+
def write_mmcif(
    mmcif_object: MmcifObject,
    output_filepath: str,
    gapless_poly_seq: bool = True,
    insert_orig_atom_names: bool = True,
    insert_alphafold_mmcif_metadata: bool = True,
):
    """Serialize an `MmcifObject` to an mmCIF file via an intermediate `Biomolecule`.

    :param mmcif_object: The parsed mmCIF object to write out.
    :param output_filepath: Destination path; opened in text write mode.
    :param gapless_poly_seq: Forwarded unchanged to `to_mmcif`.
    :param insert_orig_atom_names: When true, the atom names gathered by
        `get_unique_res_atom_names` are passed to `to_mmcif`; otherwise `None` is passed.
    :param insert_alphafold_mmcif_metadata: Forwarded unchanged to `to_mmcif`.
    """
    biomol = _from_mmcif_object(mmcif_object)

    if insert_orig_atom_names:
        atom_names = get_unique_res_atom_names(mmcif_object)
    else:
        atom_names = None

    # Build the full string first so nothing is written if serialization fails.
    mmcif_string = to_mmcif(
        biomol,
        mmcif_object.file_id,
        gapless_poly_seq=gapless_poly_seq,
        insert_alphafold_mmcif_metadata=insert_alphafold_mmcif_metadata,
        unique_res_atom_names=atom_names,
    )

    with open(output_filepath, "w") as out_file:
        out_file.write(mmcif_string)

scripts/cluster_pdb_mmcifs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@
3535
from sklearn.cluster import AgglomerativeClustering
3636
from tqdm import tqdm
3737

38+
from alphafold3_pytorch.data import mmcif_parsing
3839
from alphafold3_pytorch.tensor_typing import IntType, typecheck
3940
from alphafold3_pytorch.utils.utils import exists, np_mode
40-
from scripts.filter_pdb_mmcifs import parse_mmcif_object
4141

4242
# Constants
4343

@@ -170,7 +170,7 @@ def parse_chain_sequences_and_interfaces_from_mmcif(
170170
"""
171171
assert filepath.endswith(".cif"), "The input file must be an mmCIF file."
172172
file_id = os.path.splitext(os.path.basename(filepath))[0]
173-
mmcif_object = parse_mmcif_object(filepath, file_id)
173+
mmcif_object = mmcif_parsing.parse_mmcif_object(filepath, file_id)
174174
model = mmcif_object.structure
175175

176176
# NOTE: After dataset filtering, only heavy (non-hydrogen) atoms remain in the structure
@@ -707,7 +707,7 @@ def cluster_interfaces(
707707
# Cluster proteins at 40% sequence homology
708708
AgglomerativeClustering(
709709
n_clusters=None,
710-
distance_threshold=40.0 + 1e-6,
710+
distance_threshold=60.0 + 1e-6,
711711
metric="precomputed",
712712
linkage="complete",
713713
).fit_predict(protein_dist_matrix)

0 commit comments

Comments
 (0)