2929#
3030
3131# %%
32+ from __future__ import annotations
3233import argparse
3334import glob
3435import os
3536import random
36- from typing import Any , Dict , List , Optional , Set , Tuple , Union
37+ from typing import Any , Dict , List , Set , Tuple
3738
3839import pandas as pd
3940from Bio .PDB import MMCIFIO , PDBIO , MMCIFParser , PDBParser
4546from pdbeccdutils .core .ccd_reader import CCDReaderResult
4647from tqdm .contrib .concurrent import process_map
4748
49+ from alphafold3_pytorch .typing import typecheck
50+
4851# Parse command-line arguments
4952
5053parser = argparse .ArgumentParser (
99102
100103# Constants
101104
102- Token = Union [ Residue , Atom ]
105+ Token = Residue | Atom
103106
104107# Section 2.5.4 of the AlphaFold 3 supplement
105108
@@ -184,6 +187,7 @@ def exists(v: Any) -> bool:
184187 return v is not None
185188
186189
190+ @typecheck
187191def parse_structure (filepath : str ) -> Structure :
188192 """Parse a structure from a PDB or mmCIF file."""
189193 if filepath .endswith (".pdb" ):
@@ -196,7 +200,7 @@ def parse_structure(filepath: str) -> Structure:
196200 structure = parser .get_structure (structure_id , filepath )
197201 return structure
198202
199-
203+ @ typecheck
200204def filter_pdb_deposition_date (
201205 structure : Structure , cutoff_date : pd .Timestamp = pd .to_datetime ("2021-09-30" )
202206) -> bool :
@@ -210,17 +214,17 @@ def filter_pdb_deposition_date(
210214 return False
211215
212216
@typecheck
def filter_resolution(structure: Structure, max_resolution: float = 9.0) -> bool:
    """Filter based on resolution.

    :param structure: Structure whose ``header`` mapping is inspected.
    :param max_resolution: Largest resolution value that still passes the filter.
    :return: ``True`` when the header carries a non-``None`` ``"resolution"``
        entry that is less than or equal to ``max_resolution``; ``False`` when
        the entry is missing, ``None``, or larger than ``max_resolution``.
    """
    # One lookup covers both the missing-key case and an explicit `None` value.
    resolution = structure.header.get("resolution")
    return resolution is not None and resolution <= max_resolution
222225
223226
227+ @typecheck
224228def filter_polymer_chains (
225229 structure : Structure , max_chains : int = 1000 , for_training : bool = False
226230) -> bool :
@@ -229,6 +233,7 @@ def filter_polymer_chains(
229233 return count <= (300 if for_training else max_chains )
230234
231235
236+ @typecheck
232237def filter_resolved_chains (structure : Structure ) -> Structure :
233238 """Filter based on number of resolved residues."""
234239 chains_to_remove = [
@@ -243,7 +248,8 @@ def filter_resolved_chains(structure: Structure) -> Structure:
243248 return structure if list (structure .get_chains ()) else None
244249
245250
246- def filter_target (structure : Structure ) -> Optional [Structure ]:
251+ @typecheck
252+ def filter_target (structure : Structure ) -> Structure | None :
247253 """Filter a target based on various criteria."""
248254 target_passes_prefilters = (
249255 filter_pdb_deposition_date (structure )
@@ -253,6 +259,7 @@ def filter_target(structure: Structure) -> Optional[Structure]:
253259 return filter_resolved_chains (structure ) if target_passes_prefilters else None
254260
255261
262+ @typecheck
256263def remove_hydrogens (structure : Structure , remove_waters : bool = True ) -> Structure :
257264 """
258265 Remove hydrogens (and optionally waters) from a structure.
@@ -284,6 +291,7 @@ def remove_hydrogens(structure: Structure, remove_waters: bool = True) -> Struct
284291 return structure
285292
286293
294+ @typecheck
287295def remove_all_unknown_residue_chains (
288296 structure : Structure , standard_residues : Set [str ]
289297) -> Structure :
@@ -300,6 +308,7 @@ def remove_all_unknown_residue_chains(
300308 return structure
301309
302310
311+ @typecheck
303312def remove_clashing_chains (
304313 structure : Structure , clash_threshold : float = 1.7 , clash_percentage : float = 0.3
305314) -> Structure :
@@ -360,6 +369,7 @@ def remove_clashing_chains(
360369 return structure
361370
362371
372+ @typecheck
363373def remove_excluded_ligands (structure : Structure , ligand_exclusion_list : Set [str ]) -> Structure :
364374 """
365375 Remove ligands in the exclusion list.
@@ -385,6 +395,7 @@ def remove_excluded_ligands(structure: Structure, ligand_exclusion_list: Set[str
385395 return structure
386396
387397
398+ @typecheck
388399def remove_non_ccd_atoms (
389400 structure : Structure , ccd_reader_results : Dict [str , CCDReaderResult ]
390401) -> Structure :
@@ -417,6 +428,7 @@ def remove_non_ccd_atoms(
417428 return structure
418429
419430
431+ @typecheck
420432def is_covalently_bonded (atom1 : Atom , atom2 : Atom ) -> bool :
421433 """
422434 Check if two atoms are covalently bonded.
@@ -431,6 +443,7 @@ def is_covalently_bonded(atom1: Atom, atom2: Atom) -> bool:
431443 return False
432444
433445
446+ @typecheck
434447def remove_leaving_atoms (
435448 structure : Structure , ccd_reader_results : Dict [str , CCDReaderResult ]
436449) -> Structure :
@@ -510,6 +523,7 @@ def remove_leaving_atoms(
510523 return structure
511524
512525
526+ @typecheck
513527def filter_large_ca_distances (structure : Structure ) -> Structure :
514528 """
515529 Filter chains with large Ca atom distances.
@@ -533,6 +547,7 @@ def filter_large_ca_distances(structure: Structure) -> Structure:
533547 return structure
534548
535549
550+ @typecheck
536551def select_closest_chains (
537552 structure : Structure ,
538553 protein_residue_center_atoms : Dict [str , str ],
@@ -541,6 +556,7 @@ def select_closest_chains(
541556) -> Structure :
542557 """Select the closest chains in large bioassemblies."""
543558
559+ @typecheck
544560 def get_tokens_from_residues (
545561 residues : List [Residue ],
546562 protein_residue_center_atoms : Dict [str , str ],
@@ -559,6 +575,7 @@ def get_tokens_from_residues(
559575 tokens .append (atom )
560576 return tokens
561577
578+ @typecheck
562579 def get_token_center_atom (
563580 token : Token ,
564581 protein_residue_center_atoms : Dict [str , str ],
@@ -574,6 +591,7 @@ def get_token_center_atom(
574591 token_center_atom = token
575592 return token_center_atom
576593
594+ @typecheck
577595 def get_token_center_atoms (
578596 tokens : List [Token ],
579597 protein_residue_center_atoms : Dict [str , str ],
@@ -588,6 +606,7 @@ def get_token_center_atoms(
588606 token_center_atoms .append (token_center_atom )
589607 return token_center_atoms
590608
609+ @typecheck
591610 def get_interface_tokens (
592611 tokens : List [Token ],
593612 protein_residue_center_atoms : Dict [str , str ],
@@ -645,6 +664,7 @@ def get_interface_tokens(
645664 return structure
646665
647666
667+ @typecheck
648668def remove_crystallization_aids (
649669 structure : Structure , crystallography_methods : Dict [str , Set [str ]]
650670) -> Structure :
@@ -672,6 +692,7 @@ def remove_crystallization_aids(
672692 return structure
673693
674694
695+ @typecheck
675696def write_structure (structure : Structure , output_filepath : str ):
676697 """Write a structure to a PDB or mmCIF file."""
677698 if output_filepath .endswith (".pdb" ):
@@ -684,6 +705,7 @@ def write_structure(structure: Structure, output_filepath: str):
684705 io .save (output_filepath )
685706
686707
708+ @typecheck
687709def process_structure (args : Tuple [str , str , bool ]):
688710 """
689711 Given an input mmCIF file, create a new processed mmCIF file
0 commit comments