Skip to content

Commit 74c4373

Browse files
authored
Add timeout decorator to pdb_dataset_curation.py (#32)
* Add timeout decorator to pdb_dataset_curation.py
* Update pyproject.toml
* Simplify timeout decorator in pdb_dataset_curation.py
* Update pyproject.toml
* Update pdb_dataset_curation.py
1 parent 0c1d641 commit 74c4373

File tree

2 files changed: +56 additions, -37 deletions

alphafold3_pytorch/pdb_dataset_curation.py

Lines changed: 55 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,15 @@
3030

3131
# %%
3232
from __future__ import annotations
33+
3334
import argparse
3435
import glob
3536
import os
3637
import random
37-
from typing import Any, Dict, List, Set, Tuple
38+
from typing import Dict, List, Set, Tuple
3839

3940
import pandas as pd
41+
import timeout_decorator
4042
from Bio.PDB import MMCIFIO, PDBIO, MMCIFParser, PDBParser
4143
from Bio.PDB.Atom import Atom
4244
from Bio.PDB.NeighborSearch import NeighborSearch
@@ -104,6 +106,10 @@
104106

105107
Token = Residue | Atom
106108

109+
PROCESS_STRUCTURE_MAX_SECONDS = (
110+
60 # Maximum time allocated to process a single structure (in seconds)
111+
)
112+
107113
# Section 2.5.4 of the AlphaFold 3 supplement
108114

109115
print("Loading the Chemical Component Dictionary (CCD) into memory...")
@@ -200,18 +206,17 @@ def parse_structure(filepath: str) -> Structure:
200206
structure = parser.get_structure(structure_id, filepath)
201207
return structure
202208

209+
203210
@typecheck
204211
def filter_pdb_deposition_date(
205212
structure: Structure, cutoff_date: pd.Timestamp = pd.to_datetime("2021-09-30")
206213
) -> bool:
207214
"""Filter based on PDB deposition date."""
208-
if (
215+
return (
209216
"deposition_date" in structure.header
210217
and exists(structure.header["deposition_date"])
211218
and pd.to_datetime(structure.header["deposition_date"]) <= cutoff_date
212-
):
213-
return True
214-
return False
219+
)
215220

216221

217222
@typecheck
@@ -524,7 +529,7 @@ def remove_leaving_atoms(
524529

525530

526531
@typecheck
527-
def filter_large_ca_distances(structure: Structure) -> Structure:
532+
def filter_large_ca_distances(structure: Structure, max_distance: float = 10.0) -> Structure:
528533
"""
529534
Filter chains with large Ca atom distances.
530535
@@ -537,7 +542,7 @@ def filter_large_ca_distances(structure: Structure) -> Structure:
537542
ca_atoms = [res["CA"] for res in chain if "CA" in res]
538543
for i, ca1 in enumerate(ca_atoms[:-1]):
539544
ca2 = ca_atoms[i + 1]
540-
if (ca1 - ca2) > 10:
545+
if (ca1 - ca2) > max_distance:
541546
chains_to_remove.append(chain.id)
542547
break
543548

@@ -712,38 +717,48 @@ def process_structure(args: Tuple[str, str, bool]):
712717
using AlphaFold 3's PDB dataset filtering criteria.
713718
"""
714719
filepath, output_dir, skip_existing = args
715-
716-
# Section 2.5.4 of the AlphaFold 3 supplement
717720
try:
718-
structure = parse_structure(filepath)
719-
output_file_dir = os.path.join(output_dir, structure.id[1:3])
720-
output_filepath = os.path.join(output_file_dir, f"{structure.id}.cif")
721-
if skip_existing and os.path.exists(output_filepath):
722-
print(f"Skipping existing output file: {output_filepath}")
723-
return
724-
os.makedirs(output_file_dir, exist_ok=True)
725-
726-
# Filtering of targets
727-
structure = filter_target(structure)
728-
if exists(structure):
729-
# Filtering of bioassemblies
730-
structure = remove_hydrogens(structure)
731-
structure = remove_all_unknown_residue_chains(structure, STANDARD_RESIDUES)
732-
structure = remove_clashing_chains(structure)
733-
structure = remove_excluded_ligands(structure, LIGAND_EXCLUSION_LIST)
734-
structure = remove_non_ccd_atoms(structure, CCD_READER_RESULTS)
735-
structure = remove_leaving_atoms(structure, CCD_READER_RESULTS)
736-
structure = filter_large_ca_distances(structure)
737-
structure = select_closest_chains(
738-
structure, PROTEIN_RESIDUE_CENTER_ATOMS, NUCLEIC_ACID_RESIDUE_CENTER_ATOMS
739-
)
740-
structure = remove_crystallization_aids(structure, CRYSTALLOGRAPHY_METHODS)
741-
if list(structure.get_chains()):
742-
# Save processed structure
743-
write_structure(structure, output_filepath)
744-
print(f"Finished processing structure: {structure.id}")
721+
with timeout_decorator.timeout(PROCESS_STRUCTURE_MAX_SECONDS, use_signals=False):
722+
# Section 2.5.4 of the AlphaFold 3 supplement
723+
structure = parse_structure(filepath)
724+
output_file_dir = os.path.join(output_dir, structure.id[1:3])
725+
output_filepath = os.path.join(output_file_dir, f"{structure.id}.cif")
726+
if skip_existing and os.path.exists(output_filepath):
727+
print(f"Skipping existing output file: {output_filepath}")
728+
return
729+
os.makedirs(output_file_dir, exist_ok=True)
730+
731+
# Filtering of targets
732+
structure = filter_target(structure)
733+
if exists(structure):
734+
# Filtering of bioassemblies
735+
structure = remove_hydrogens(structure)
736+
structure = remove_all_unknown_residue_chains(structure, STANDARD_RESIDUES)
737+
structure = remove_clashing_chains(structure)
738+
structure = remove_excluded_ligands(structure, LIGAND_EXCLUSION_LIST)
739+
structure = remove_non_ccd_atoms(structure, CCD_READER_RESULTS)
740+
structure = remove_leaving_atoms(structure, CCD_READER_RESULTS)
741+
structure = filter_large_ca_distances(structure)
742+
structure = select_closest_chains(
743+
structure, PROTEIN_RESIDUE_CENTER_ATOMS, NUCLEIC_ACID_RESIDUE_CENTER_ATOMS
744+
)
745+
structure = remove_crystallization_aids(structure, CRYSTALLOGRAPHY_METHODS)
746+
if list(structure.get_chains()):
747+
# Save processed structure
748+
write_structure(structure, output_filepath)
749+
print(f"Finished processing structure: {structure.id}")
745750
except Exception as e:
746751
print(f"Skipping structure processing of {filepath} due to: {e}")
752+
structure_id = os.path.splitext(os.path.basename(filepath))[0]
753+
output_file_dir = os.path.join(output_dir, structure_id[1:3])
754+
output_filepath = os.path.join(output_file_dir, f"{structure_id}.cif")
755+
if os.path.exists(output_filepath):
756+
try:
757+
os.remove(output_filepath)
758+
except Exception as e:
759+
print(
760+
f"Failed to remove partially processed file {output_filepath} due to: {e}. Skipping its removal..."
761+
)
747762

748763

749764
if __name__ == '__main__':
@@ -754,5 +769,8 @@ def process_structure(args: Tuple[str, str, bool]):
754769
for filepath in glob.glob(os.path.join(args.mmcif_dir, "*", "*.cif"))
755770
]
756771
process_map(
757-
process_structure, args_tuples, max_workers=args.num_workers, chunksize=args.worker_chunk_size
772+
process_structure,
773+
args_tuples,
774+
max_workers=args.num_workers,
775+
chunksize=args.worker_chunk_size,
758776
)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ dependencies = [
3838
"pdbeccdutils>=0.8.5",
3939
"pydantic>=2.7.2",
4040
"taylor-series-linear-attention>=0.1.9",
41+
"timeout_decorator>=0.5.0",
4142
'torch_geometric',
4243
"torch>=2.1",
4344
"tqdm>=4.66.4",

0 commit comments

Comments
 (0)