Skip to content

Commit a04e0cc

Browse files
committed
farewell mkdssp
1 parent 8975786 commit a04e0cc

File tree

2 files changed

+5
-131
lines changed

2 files changed

+5
-131
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 4 additions & 125 deletions
Original file line number | Diff line number | Diff line change
@@ -129,11 +129,8 @@
129129

130130
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
131131

132-
from Bio.PDB.StructureBuilder import StructureBuilder
133132
from Bio.PDB.Structure import Structure
134-
from Bio.PDB.PDBIO import PDBIO
135-
from Bio.PDB.DSSP import DSSP
136-
import tempfile
133+
from Bio.PDB.StructureBuilder import StructureBuilder
137134

138135
"""
139136
global ein notation:
@@ -5291,8 +5288,6 @@ def __init__(
52915288
contact_mask_threshold: float = 8.0,
52925289
is_fine_tuning: bool = False,
52935290
weight_dict_config: dict = None,
5294-
dssp_path: str = "mkdssp",
5295-
use_inhouse_rsa_calculation: bool = False
52965291
):
52975292
super().__init__()
52985293
self.compute_confidence_score = ComputeConfidenceScore(eps=eps)
@@ -5306,10 +5301,6 @@ def __init__(
53065301
self.register_buffer("dist_breaks", dist_breaks)
53075302
self.register_buffer('lddt_thresholds', torch.tensor([0.5, 1.0, 2.0, 4.0]))
53085303

5309-
self.dssp_path = dssp_path
5310-
5311-
self.use_inhouse_rsa_calculation = use_inhouse_rsa_calculation
5312-
53135304
atom_type_radii = tensor([
53145305
1.65, # 0 - nitrogen
53155306
1.87, # 1 - carbon alpha
@@ -5328,22 +5319,6 @@ def __init__(
53285319

53295320
self.register_buffer('atom_radii', atom_type_radii, persistent = False)
53305321

5331-
@property
5332-
def is_mkdssp_available(self):
5333-
"""Check if `mkdssp` is available.
5334-
5335-
:return: True if `mkdssp` is available
5336-
"""
5337-
try:
5338-
sh.which(self.dssp_path)
5339-
return True
5340-
except sh.ErrorReturnCode_1:
5341-
return False
5342-
5343-
@property
5344-
def can_calculate_unresolved_protein_rasa(self):
5345-
return self.is_mkdssp_available or self.use_inhouse_rsa_calculation
5346-
53475322
@typecheck
53485323
def compute_gpde(
53495324
self,
@@ -5633,92 +5608,6 @@ def compute_weighted_lddt(
56335608

56345609
return weighted_lddt
56355610

5636-
@typecheck
5637-
def _compute_unresolved_rasa(
5638-
self,
5639-
unresolved_cid: int,
5640-
unresolved_residue_mask: Bool[" n"],
5641-
asym_id: Int[" n"],
5642-
molecule_ids: Int[" n"],
5643-
molecule_atom_lens: Int[" n"],
5644-
atom_pos: Float["m 3"],
5645-
atom_mask: Bool[" m"],
5646-
) -> Float[""]:
5647-
"""Compute the unresolved relative solvent accessible surface area (RASA) for proteins.
5648-
5649-
unresolved_cid: asym_id for protein chains with unresolved residues
5650-
unresolved_residue_mask: True for unresolved residues, False for resolved residues
5651-
asym_id: asym_id for each residue
5652-
molecule_ids: molecule_ids for each residue
5653-
molecule_atom_lens: number of atoms for each residue
5654-
atom_pos: [m 3] atom positions
5655-
atom_mask: True for valid atoms, False for missing/padding atoms
5656-
:return: unresolved RASA
5657-
"""
5658-
5659-
assert self.can_calculate_unresolved_protein_rasa, "`mkdssp` needs to be installed"
5660-
5661-
residue_constants = get_residue_constants(res_chem_index=IS_PROTEIN)
5662-
5663-
device = atom_pos.device
5664-
dtype = atom_pos.dtype
5665-
num_atom = atom_pos.shape[0]
5666-
5667-
chain_mask = asym_id == unresolved_cid
5668-
chain_unresolved_residue_mask = unresolved_residue_mask[chain_mask]
5669-
chain_asym_id = asym_id[chain_mask]
5670-
chain_molecule_ids = molecule_ids[chain_mask]
5671-
chain_molecule_atom_lens = molecule_atom_lens[chain_mask]
5672-
5673-
chain_mask_to_atom = torch.repeat_interleave(chain_mask, molecule_atom_lens)
5674-
5675-
# if there's padding in num atom
5676-
num_pad = num_atom - molecule_atom_lens.sum()
5677-
if num_pad > 0:
5678-
chain_mask_to_atom = F.pad(chain_mask_to_atom, (0, num_pad), value=False)
5679-
5680-
chain_atom_pos = atom_pos[chain_mask_to_atom]
5681-
chain_atom_mask = atom_mask[chain_mask_to_atom]
5682-
5683-
structure = protein_structure_from_feature(
5684-
chain_asym_id,
5685-
chain_molecule_ids,
5686-
chain_molecule_atom_lens,
5687-
chain_atom_pos,
5688-
chain_atom_mask,
5689-
)
5690-
5691-
with tempfile.NamedTemporaryFile(suffix=".pdb", delete=True) as temp_file:
5692-
temp_file_path = temp_file.name
5693-
5694-
pdb_writer = PDBIO()
5695-
pdb_writer.set_structure(structure)
5696-
pdb_writer.save(temp_file_path)
5697-
dssp = DSSP(structure[0], temp_file_path, dssp=self.dssp_path)
5698-
dssp_dict = dict(dssp)
5699-
5700-
rasa = []
5701-
aatypes = []
5702-
for residue in structure.get_residues():
5703-
rsa = float(dssp_dict.get((residue.get_full_id()[2], residue.id))[3])
5704-
rasa.append(rsa)
5705-
5706-
aatype = dssp_dict.get((residue.get_full_id()[2], residue.id))[1]
5707-
aatypes.append(residue_constants.restype_order[aatype])
5708-
5709-
rasa = torch.tensor(rasa, dtype=dtype, device=device)
5710-
aatypes = torch.tensor(aatypes, device=device).int()
5711-
5712-
unresolved_aatypes = aatypes[chain_unresolved_residue_mask]
5713-
unresolved_molecule_ids = chain_molecule_ids[chain_unresolved_residue_mask]
5714-
5715-
assert torch.equal(
5716-
unresolved_aatypes, unresolved_molecule_ids
5717-
), "aatype not match for input feature and structure"
5718-
unresolved_rasa = rasa[chain_unresolved_residue_mask]
5719-
5720-
return unresolved_rasa.mean()
5721-
57225611
@typecheck
57235612
def calc_atom_access_surface_score_from_structure(
57245613
self,
@@ -5758,7 +5647,7 @@ def calc_atom_access_surface_score(
57585647
atom_pos: Float['m 3'],
57595648
atom_type: Int['m'],
57605649
molecule_atom_lens: Int['n'] | None = None,
5761-
fibonacci_sphere_n = 200, # they use 200 in mkdssp, but can be tailored for efficiency
5650+
fibonacci_sphere_n = 200, # more points equal better approximation at cost of compute
57625651
atom_distance_min_thres = 1e-4
57635652
) -> Float['m'] | Float['n']:
57645653

@@ -5852,7 +5741,7 @@ def calc_atom_access_surface_score(
58525741
return rasa
58535742

58545743
@typecheck
5855-
def _inhouse_compute_unresolved_rasa(
5744+
def _compute_unresolved_rasa(
58565745
self,
58575746
unresolved_cid: int,
58585747
unresolved_residue_mask: Bool["n"],
@@ -5876,8 +5765,6 @@ def _inhouse_compute_unresolved_rasa(
58765765
:return: unresolved RASA
58775766
"""
58785767

5879-
assert self.can_calculate_unresolved_protein_rasa, "`mkdssp` needs to be installed"
5880-
58815768
num_atom = atom_pos.shape[0]
58825769

58835770
chain_mask = asym_id == unresolved_cid
@@ -5949,16 +5836,8 @@ def compute_unresolved_rasa(
59495836
weight = weight_dict.get("unresolved", {}).get("unresolved", None)
59505837
assert weight, "Weight not found for unresolved"
59515838

5952-
# for migrating to a rewritten computation of RSA
5953-
# to remove mkdssp dependency for model selection
5954-
5955-
if self.use_inhouse_rsa_calculation:
5956-
compute_unresolved_rasa_function = self._inhouse_compute_unresolved_rasa
5957-
else:
5958-
compute_unresolved_rasa_function = self._compute_unresolved_rasa
5959-
59605839
unresolved_rasa = [
5961-
compute_unresolved_rasa_function(*args)
5840+
self._compute_unresolved_rasa(*args)
59625841
for args in zip(
59635842
unresolved_cid,
59645843
unresolved_residue_mask,

tests/test_af3.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1501,12 +1501,7 @@ def test_unresolved_protein_rasa():
15011501

15021502
unresolved_residue_mask = torch.randint(0, 2, asym_id.shape).bool()
15031503

1504-
compute_model_selection_score = ComputeModelSelectionScore(
1505-
use_inhouse_rsa_calculation = True
1506-
)
1507-
1508-
if not compute_model_selection_score.can_calculate_unresolved_protein_rasa:
1509-
pytest.skip("mkdssp not available for calculating unresolved protein rasa")
1504+
compute_model_selection_score = ComputeModelSelectionScore()
15101505

15111506
# only test with protein
15121507

0 commit comments

Comments
 (0)