Skip to content

Commit 419b024

Browse files
authored
Fix functionality of argparse arguments in pdb_dataset_curation.py (#35)
1 parent e4c5b7d commit 419b024

File tree

1 file changed

+88
-80
lines changed

1 file changed

+88
-80
lines changed

alphafold3_pytorch/pdb_dataset_curation.py

Lines changed: 88 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
import pandas as pd
4141
import timeout_decorator
4242
from Bio.PDB import MMCIFIO, PDBIO, MMCIFParser, PDBParser
43-
from Bio.PDB.Atom import Atom
43+
from Bio.PDB.Atom import Atom, DisorderedAtom
4444
from Bio.PDB.NeighborSearch import NeighborSearch
4545
from Bio.PDB.Residue import Residue
4646
from Bio.PDB.Structure import Structure
@@ -50,76 +50,14 @@
5050

5151
from alphafold3_pytorch.typing import typecheck
5252

53-
# Parse command-line arguments
54-
55-
parser = argparse.ArgumentParser(
56-
description="Process mmCIF files to curate the AlphaFold 3 PDB dataset."
57-
)
58-
parser.add_argument(
59-
"--mmcif_dir",
60-
type=str,
61-
default=os.path.join("data", "mmCIF"),
62-
help="Path to the input directory containing mmCIF files to process.",
63-
)
64-
parser.add_argument(
65-
"--ccd_dir",
66-
type=str,
67-
default=os.path.join("data", "CCD"),
68-
help="Path to the directory containing CCD files to reference during data processing.",
69-
)
70-
parser.add_argument(
71-
"--output_dir",
72-
type=str,
73-
default=os.path.join("data", "PDB_set"),
74-
help="Path to the output directory in which to store processed mmCIF dataset files.",
75-
)
76-
parser.add_argument(
77-
"--skip_existing",
78-
action="store_true",
79-
help="Skip processing of existing output files.",
80-
)
81-
parser.add_argument(
82-
"--num_workers",
83-
type=int,
84-
default=1,
85-
help="Number of worker processes to use for parallel processing.",
86-
)
87-
parser.add_argument(
88-
"--worker_chunk_size",
89-
type=int,
90-
default=1,
91-
help="Size of mmCIF file chunks sent to worker processes.",
92-
)
93-
args = parser.parse_args("")
94-
95-
assert os.path.exists(args.mmcif_dir), f"Input directory {args.mmcif_dir} does not exist."
96-
assert os.path.exists(args.ccd_dir), f"CCD directory {args.ccd_dir} does not exist."
97-
assert os.path.exists(
98-
os.path.join(args.ccd_dir, "chem_comp_model.cif")
99-
), f"CCD ligands file not found in {args.ccd_dir}."
100-
assert os.path.exists(
101-
os.path.join(args.ccd_dir, "components.cif")
102-
), f"CCD components file not found in {args.ccd_dir}."
103-
os.makedirs(args.output_dir, exist_ok=True)
104-
10553
# Constants
10654

107-
Token = Residue | Atom
55+
Token = Residue | Atom | DisorderedAtom
10856

10957
PROCESS_STRUCTURE_MAX_SECONDS = (
11058
60 # Maximum time allocated to process a single structure (in seconds)
11159
)
11260

113-
# Section 2.5.4 of the AlphaFold 3 supplement
114-
115-
print("Loading the Chemical Component Dictionary (CCD) into memory...")
116-
CCD_READER_RESULTS = ccd_reader.read_pdb_components_file(
117-
# Load globally to share amongst all worker processes
118-
os.path.join(args.ccd_dir, "components.cif"),
119-
sanitize=False, # Reduce loading time
120-
)
121-
print("Finished loading the Chemical Component Dictionary (CCD) into memory.")
122-
12361
COVALENT_BOND_THRESHOLDS = {
12462
# Threshold distances for covalent bonds (in Ångströms)
12563
# These thresholds may vary depending on the specific types of bonds you are looking for
@@ -239,7 +177,7 @@ def filter_polymer_chains(
239177

240178

241179
@typecheck
242-
def filter_resolved_chains(structure: Structure) -> Structure:
180+
def filter_resolved_chains(structure: Structure) -> Structure | None:
243181
"""Filter based on number of resolved residues."""
244182
chains_to_remove = [
245183
chain.id
@@ -434,7 +372,7 @@ def remove_non_ccd_atoms(
434372

435373

436374
@typecheck
437-
def is_covalently_bonded(atom1: Atom, atom2: Atom) -> bool:
375+
def is_covalently_bonded(atom1: Atom | DisorderedAtom, atom2: Atom | DisorderedAtom) -> bool:
438376
"""
439377
Check if two atoms are covalently bonded.
440378
@@ -444,7 +382,7 @@ def is_covalently_bonded(atom1: Atom, atom2: Atom) -> bool:
444382
"""
445383
bond_type = tuple(sorted([atom1.element, atom2.element]))
446384
if bond_type in COVALENT_BOND_THRESHOLDS:
447-
return (atom1 - atom2) <= COVALENT_BOND_THRESHOLDS[bond_type]
385+
return ((atom1 - atom2) <= COVALENT_BOND_THRESHOLDS[bond_type]).item()
448386
return False
449387

450388

@@ -585,7 +523,7 @@ def get_token_center_atom(
585523
token: Token,
586524
protein_residue_center_atoms: Dict[str, str],
587525
nucleic_acid_residue_center_atoms: Dict[str, str],
588-
) -> Atom:
526+
) -> Atom | DisorderedAtom:
589527
"""Get center atom of a token."""
590528
if isinstance(token, Residue):
591529
if token.resname in protein_residue_center_atoms:
@@ -601,7 +539,7 @@ def get_token_center_atoms(
601539
tokens: List[Token],
602540
protein_residue_center_atoms: Dict[str, str],
603541
nucleic_acid_residue_center_atoms: Dict[str, str],
604-
) -> List[Atom]:
542+
) -> List[Atom | DisorderedAtom]:
605543
"""Get center atoms of tokens."""
606544
token_center_atoms = []
607545
for token in tokens:
@@ -712,22 +650,20 @@ def write_structure(structure: Structure, output_filepath: str):
712650

713651
@typecheck
714652
@timeout_decorator.timeout(PROCESS_STRUCTURE_MAX_SECONDS, use_signals=False)
715-
def process_structure_with_timeout(filepath: str, output_dir: str, skip_existing: bool = False):
653+
def process_structure_with_timeout(filepath: str, output_dir: str):
716654
"""
717655
Given an input mmCIF file, create a new processed mmCIF file
718656
using AlphaFold 3's PDB dataset filtering criteria under a
719657
timeout constraint.
720658
"""
721659
# Section 2.5.4 of the AlphaFold 3 supplement
722-
structure = parse_structure(filepath)
723-
output_file_dir = os.path.join(output_dir, structure.id[1:3])
724-
output_filepath = os.path.join(output_file_dir, f"{structure.id}.cif")
725-
if skip_existing and os.path.exists(output_filepath):
726-
print(f"Skipping existing output file: {output_filepath}")
727-
return
660+
structure_id = os.path.splitext(os.path.basename(filepath))[0]
661+
output_file_dir = os.path.join(output_dir, structure_id[1:3])
662+
output_filepath = os.path.join(output_file_dir, f"{structure_id}.cif")
728663
os.makedirs(output_file_dir, exist_ok=True)
729664

730665
# Filtering of targets
666+
structure = parse_structure(filepath)
731667
structure = filter_target(structure)
732668
if exists(structure):
733669
# Filtering of bioassemblies
@@ -755,13 +691,17 @@ def process_structure(args: Tuple[str, str, bool]):
755691
using AlphaFold 3's PDB dataset filtering criteria.
756692
"""
757693
filepath, output_dir, skip_existing = args
694+
structure_id = os.path.splitext(os.path.basename(filepath))[0]
695+
output_file_dir = os.path.join(output_dir, structure_id[1:3])
696+
output_filepath = os.path.join(output_file_dir, f"{structure_id}.cif")
697+
if skip_existing and os.path.exists(output_filepath):
698+
print(f"Skipping existing output file: {output_filepath}")
699+
return
700+
758701
try:
759-
process_structure_with_timeout(filepath, output_dir, skip_existing)
702+
process_structure_with_timeout(filepath, output_dir)
760703
except Exception as e:
761704
print(f"Skipping structure processing of {filepath} due to: {e}")
762-
structure_id = os.path.splitext(os.path.basename(filepath))[0]
763-
output_file_dir = os.path.join(output_dir, structure_id[1:3])
764-
output_filepath = os.path.join(output_file_dir, f"{structure_id}.cif")
765705
if os.path.exists(output_filepath):
766706
try:
767707
os.remove(output_filepath)
@@ -772,6 +712,74 @@ def process_structure(args: Tuple[str, str, bool]):
772712

773713

774714
if __name__ == '__main__':
715+
# Parse command-line arguments
716+
717+
parser = argparse.ArgumentParser(
718+
description="Process mmCIF files to curate the AlphaFold 3 PDB dataset."
719+
)
720+
parser.add_argument(
721+
"-i",
722+
"--mmcif_dir",
723+
type=str,
724+
default=os.path.join("data", "mmCIF"),
725+
help="Path to the input directory containing mmCIF files to process.",
726+
)
727+
parser.add_argument(
728+
"-c",
729+
"--ccd_dir",
730+
type=str,
731+
default=os.path.join("data", "CCD"),
732+
help="Path to the directory containing CCD files to reference during data processing.",
733+
)
734+
parser.add_argument(
735+
"-o",
736+
"--output_dir",
737+
type=str,
738+
default=os.path.join("data", "PDB_set"),
739+
help="Path to the output directory in which to store processed mmCIF dataset files.",
740+
)
741+
parser.add_argument(
742+
"-s",
743+
"--skip_existing",
744+
action="store_true",
745+
help="Skip processing of existing output files.",
746+
)
747+
parser.add_argument(
748+
"-n",
749+
"--num_workers",
750+
type=int,
751+
default=1,
752+
help="Number of worker processes to use for parallel processing.",
753+
)
754+
parser.add_argument(
755+
"-w",
756+
"--worker_chunk_size",
757+
type=int,
758+
default=1,
759+
help="Size of mmCIF file chunks sent to worker processes.",
760+
)
761+
args = parser.parse_args()
762+
763+
assert os.path.exists(args.mmcif_dir), f"Input directory {args.mmcif_dir} does not exist."
764+
assert os.path.exists(args.ccd_dir), f"CCD directory {args.ccd_dir} does not exist."
765+
assert os.path.exists(
766+
os.path.join(args.ccd_dir, "chem_comp_model.cif")
767+
), f"CCD ligands file not found in {args.ccd_dir}."
768+
assert os.path.exists(
769+
os.path.join(args.ccd_dir, "components.cif")
770+
), f"CCD components file not found in {args.ccd_dir}."
771+
os.makedirs(args.output_dir, exist_ok=True)
772+
773+
# Load the Chemical Component Dictionary (CCD) into memory
774+
775+
print("Loading the Chemical Component Dictionary (CCD) into memory...")
776+
CCD_READER_RESULTS = ccd_reader.read_pdb_components_file(
777+
# Load globally to share amongst all worker processes
778+
os.path.join(args.ccd_dir, "components.cif"),
779+
sanitize=False, # Reduce loading time
780+
)
781+
print("Finished loading the Chemical Component Dictionary (CCD) into memory.")
782+
775783
# Process structures across all worker processes
776784

777785
args_tuples = [

0 commit comments

Comments
 (0)