129129
130130from huggingface_hub import PyTorchModelHubMixin , hf_hub_download
131131
132- from Bio .PDB .StructureBuilder import StructureBuilder
133132from Bio .PDB .Structure import Structure
134- from Bio .PDB .PDBIO import PDBIO
135- from Bio .PDB .DSSP import DSSP
136- import tempfile
133+ from Bio .PDB .StructureBuilder import StructureBuilder
137134
138135"""
139136global ein notation:
@@ -5291,8 +5288,6 @@ def __init__(
52915288 contact_mask_threshold : float = 8.0 ,
52925289 is_fine_tuning : bool = False ,
52935290 weight_dict_config : dict = None ,
5294- dssp_path : str = "mkdssp" ,
5295- use_inhouse_rsa_calculation : bool = False
52965291 ):
52975292 super ().__init__ ()
52985293 self .compute_confidence_score = ComputeConfidenceScore (eps = eps )
@@ -5306,10 +5301,6 @@ def __init__(
53065301 self .register_buffer ("dist_breaks" , dist_breaks )
53075302 self .register_buffer ('lddt_thresholds' , torch .tensor ([0.5 , 1.0 , 2.0 , 4.0 ]))
53085303
5309- self .dssp_path = dssp_path
5310-
5311- self .use_inhouse_rsa_calculation = use_inhouse_rsa_calculation
5312-
53135304 atom_type_radii = tensor ([
53145305 1.65 , # 0 - nitrogen
53155306 1.87 , # 1 - carbon alpha
@@ -5328,22 +5319,6 @@ def __init__(
53285319
53295320 self .register_buffer ('atom_radii' , atom_type_radii , persistent = False )
53305321
5331- @property
5332- def is_mkdssp_available (self ):
5333- """Check if `mkdssp` is available.
5334-
5335- :return: True if `mkdssp` is available
5336- """
5337- try :
5338- sh .which (self .dssp_path )
5339- return True
5340- except sh .ErrorReturnCode_1 :
5341- return False
5342-
5343- @property
5344- def can_calculate_unresolved_protein_rasa (self ):
5345- return self .is_mkdssp_available or self .use_inhouse_rsa_calculation
5346-
53475322 @typecheck
53485323 def compute_gpde (
53495324 self ,
@@ -5633,92 +5608,6 @@ def compute_weighted_lddt(
56335608
56345609 return weighted_lddt
56355610
5636- @typecheck
5637- def _compute_unresolved_rasa (
5638- self ,
5639- unresolved_cid : int ,
5640- unresolved_residue_mask : Bool [" n" ],
5641- asym_id : Int [" n" ],
5642- molecule_ids : Int [" n" ],
5643- molecule_atom_lens : Int [" n" ],
5644- atom_pos : Float ["m 3" ],
5645- atom_mask : Bool [" m" ],
5646- ) -> Float ["" ]:
5647- """Compute the unresolved relative solvent accessible surface area (RASA) for proteins.
5648-
5649- unresolved_cid: asym_id for protein chains with unresolved residues
5650- unresolved_residue_mask: True for unresolved residues, False for resolved residues
5651- asym_id: asym_id for each residue
5652- molecule_ids: molecule_ids for each residue
5653- molecule_atom_lens: number of atoms for each residue
5654- atom_pos: [m 3] atom positions
5655- atom_mask: True for valid atoms, False for missing/padding atoms
5656- :return: unresolved RASA
5657- """
5658-
5659- assert self .can_calculate_unresolved_protein_rasa , "`mkdssp` needs to be installed"
5660-
5661- residue_constants = get_residue_constants (res_chem_index = IS_PROTEIN )
5662-
5663- device = atom_pos .device
5664- dtype = atom_pos .dtype
5665- num_atom = atom_pos .shape [0 ]
5666-
5667- chain_mask = asym_id == unresolved_cid
5668- chain_unresolved_residue_mask = unresolved_residue_mask [chain_mask ]
5669- chain_asym_id = asym_id [chain_mask ]
5670- chain_molecule_ids = molecule_ids [chain_mask ]
5671- chain_molecule_atom_lens = molecule_atom_lens [chain_mask ]
5672-
5673- chain_mask_to_atom = torch .repeat_interleave (chain_mask , molecule_atom_lens )
5674-
5675- # if there's padding in num atom
5676- num_pad = num_atom - molecule_atom_lens .sum ()
5677- if num_pad > 0 :
5678- chain_mask_to_atom = F .pad (chain_mask_to_atom , (0 , num_pad ), value = False )
5679-
5680- chain_atom_pos = atom_pos [chain_mask_to_atom ]
5681- chain_atom_mask = atom_mask [chain_mask_to_atom ]
5682-
5683- structure = protein_structure_from_feature (
5684- chain_asym_id ,
5685- chain_molecule_ids ,
5686- chain_molecule_atom_lens ,
5687- chain_atom_pos ,
5688- chain_atom_mask ,
5689- )
5690-
5691- with tempfile .NamedTemporaryFile (suffix = ".pdb" , delete = True ) as temp_file :
5692- temp_file_path = temp_file .name
5693-
5694- pdb_writer = PDBIO ()
5695- pdb_writer .set_structure (structure )
5696- pdb_writer .save (temp_file_path )
5697- dssp = DSSP (structure [0 ], temp_file_path , dssp = self .dssp_path )
5698- dssp_dict = dict (dssp )
5699-
5700- rasa = []
5701- aatypes = []
5702- for residue in structure .get_residues ():
5703- rsa = float (dssp_dict .get ((residue .get_full_id ()[2 ], residue .id ))[3 ])
5704- rasa .append (rsa )
5705-
5706- aatype = dssp_dict .get ((residue .get_full_id ()[2 ], residue .id ))[1 ]
5707- aatypes .append (residue_constants .restype_order [aatype ])
5708-
5709- rasa = torch .tensor (rasa , dtype = dtype , device = device )
5710- aatypes = torch .tensor (aatypes , device = device ).int ()
5711-
5712- unresolved_aatypes = aatypes [chain_unresolved_residue_mask ]
5713- unresolved_molecule_ids = chain_molecule_ids [chain_unresolved_residue_mask ]
5714-
5715- assert torch .equal (
5716- unresolved_aatypes , unresolved_molecule_ids
5717- ), "aatype not match for input feature and structure"
5718- unresolved_rasa = rasa [chain_unresolved_residue_mask ]
5719-
5720- return unresolved_rasa .mean ()
5721-
57225611 @typecheck
57235612 def calc_atom_access_surface_score_from_structure (
57245613 self ,
@@ -5758,7 +5647,7 @@ def calc_atom_access_surface_score(
57585647 atom_pos : Float ['m 3' ],
57595648 atom_type : Int ['m' ],
57605649 molecule_atom_lens : Int ['n' ] | None = None ,
5761- fibonacci_sphere_n = 200 , # they use 200 in mkdssp, but can be tailored for efficiency
5650+ fibonacci_sphere_n = 200 , # more points equal better approximation at cost of compute
57625651 atom_distance_min_thres = 1e-4
57635652 ) -> Float ['m' ] | Float ['n' ]:
57645653
@@ -5852,7 +5741,7 @@ def calc_atom_access_surface_score(
58525741 return rasa
58535742
58545743 @typecheck
5855- def _inhouse_compute_unresolved_rasa (
5744+ def _compute_unresolved_rasa (
58565745 self ,
58575746 unresolved_cid : int ,
58585747 unresolved_residue_mask : Bool ["n" ],
@@ -5876,8 +5765,6 @@ def _inhouse_compute_unresolved_rasa(
58765765 :return: unresolved RASA
58775766 """
58785767
5879- assert self .can_calculate_unresolved_protein_rasa , "`mkdssp` needs to be installed"
5880-
58815768 num_atom = atom_pos .shape [0 ]
58825769
58835770 chain_mask = asym_id == unresolved_cid
@@ -5949,16 +5836,8 @@ def compute_unresolved_rasa(
59495836 weight = weight_dict .get ("unresolved" , {}).get ("unresolved" , None )
59505837 assert weight , "Weight not found for unresolved"
59515838
5952- # for migrating to a rewritten computation of RSA
5953- # to remove mkdssp dependency for model selection
5954-
5955- if self .use_inhouse_rsa_calculation :
5956- compute_unresolved_rasa_function = self ._inhouse_compute_unresolved_rasa
5957- else :
5958- compute_unresolved_rasa_function = self ._compute_unresolved_rasa
5959-
59605839 unresolved_rasa = [
5961- compute_unresolved_rasa_function (* args )
5840+ self . _compute_unresolved_rasa (* args )
59625841 for args in zip (
59635842 unresolved_cid ,
59645843 unresolved_residue_mask ,
0 commit comments