Skip to content

Commit e3c313c

Browse files
authored
Add support for plotting plDDTs in output mmCIFs for visualization (#178)
* Update mmcif_writing.py * Update alphafold3.py * Update alphafold3.py * Update alphafold3.py
1 parent 0cc2956 commit e3c313c

File tree

2 files changed

+47
-51
lines changed

2 files changed

+47
-51
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 37 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4286,7 +4286,8 @@ def _protein_structure_from_feature(
42864286

42874287
return builder.get_structure()
42884288

4289-
ScoredSample = Tuple[int, Float["b m 3"], Float[" b"], Float[" b"]] # type: ignore
4289+
Sample = Tuple[Float["b m 3"], Float["b pde n n"], Float["b m"], Float["b dist n n"]]
4290+
ScoredSample = Tuple[int, Float["b m 3"], Float["b m"], Float[" b"], Float[" b"]]
42904291

42914292
class ScoreDetails(NamedTuple):
42924293
best_gpde_index: int
@@ -4799,27 +4800,22 @@ def compute_unresolved_rasa(
47994800
def compute_model_selection_score(
48004801
self,
48014802
batch: BatchedAtomInput,
4802-
samples: List[Tuple[
4803-
Float["b m 3"],
4804-
Float["b pde n n"],
4805-
Float["b dist n n"]
4806-
]],
4803+
samples: List[Sample],
48074804
is_fine_tuning: bool = None,
48084805
return_details: bool = False,
48094806
return_unweighted_scores: bool = False,
48104807
compute_rasa: bool = False,
48114808
unresolved_cid: List[int] | None = None,
4812-
unresolved_residue_mask: Bool["b n"] | None = None,
4809+
unresolved_residue_mask: Bool["b n"] | None = None,
48134810
missing_chain_index: int = -1,
48144811
) -> Float[" b"] | ScoreDetails:
4815-
48164812
"""Compute the model selection score for an input batch and corresponding (sampled) atom
48174813
positions.
48184814
48194815
:param batch: A batch of `AtomInput` data.
48204816
:param samples: A list of sampled atom positions along with their predicted distance errors and labels.
48214817
:param is_fine_tuning: is fine tuning
4822-
:param return_top_model: return the top-ranked sample
4818+
:param return_details: return the top model and its score
48234819
:param return_unweighted_scores: return the unweighted scores (i.e., lDDT)
48244820
:param compute_rasa: compute the relative solvent accessible surface area (RASA) for unresolved proteins
48254821
:param unresolved_cid: unresolved chain ids
@@ -4861,7 +4857,7 @@ def compute_model_selection_score(
48614857
scored_samples: List[ScoredSample] = []
48624858

48634859
for sample_idx, sample in enumerate(samples):
4864-
atom_pos_pred, pde_logits, dist_logits = sample
4860+
atom_pos_pred, pde_logits, plddt, dist_logits = sample
48654861

48664862
weighted_lddt = self.compute_weighted_lddt(
48674863
atom_pos_pred,
@@ -4886,50 +4882,51 @@ def compute_model_selection_score(
48864882
tok_repr_atm_mask,
48874883
)
48884884

4889-
scored_samples.append((sample_idx, atom_pos_pred, weighted_lddt, gpde))
4885+
scored_samples.append((sample_idx, atom_pos_pred, plddt, weighted_lddt, gpde))
48904886

48914887
# quick collate
48924888

48934889
*_, all_weighted_lddt, all_gpde = zip(*scored_samples)
48944890

4895-
# rank by batch-averaged gPDE
4891+
# rank by batch-averaged minimum gPDE
48964892

4897-
best_gpde_index = torch.stack(all_gpde).mean(dim = -1).argmax().item()
4893+
best_gpde_index = torch.stack(all_gpde).mean(dim=-1).argmin().item()
48984894

4899-
# rank by batch-averaged lDDT
4895+
# rank by batch-averaged maximum lDDT
49004896

4901-
best_lddt_index = torch.stack(all_weighted_lddt).mean(dim = -1).argmax().item()
4897+
best_lddt_index = torch.stack(all_weighted_lddt).mean(dim=-1).argmax().item()
49024898

49034899
# some weighted score
49044900

49054901
model_selection_score = (
4906-
scored_samples[best_gpde_index][-2] +
4907-
scored_samples[best_lddt_index][-2]
4902+
scored_samples[best_gpde_index][-2] + scored_samples[best_lddt_index][-2]
49084903
) / 2
49094904

49104905
if not return_details:
49114906
return model_selection_score
49124907

49134908
score_details = ScoreDetails(
4914-
best_gpde_index = best_gpde_index,
4915-
best_lddt_index = best_lddt_index,
4916-
score = model_selection_score,
4917-
scored_samples = scored_samples
4909+
best_gpde_index=best_gpde_index,
4910+
best_lddt_index=best_lddt_index,
4911+
score=model_selection_score,
4912+
scored_samples=scored_samples,
49184913
)
49194914

49204915
return score_details
49214916

49224917
@typecheck
49234918
def forward(
4924-
self,
4925-
alphafolds: Tuple[Alphafold3],
4926-
batched_atom_inputs: BatchedAtomInput,
4927-
**kwargs
4919+
self, alphafolds: Tuple[Alphafold3], batched_atom_inputs: BatchedAtomInput, **kwargs
49284920
) -> Float[" b"] | ScoreDetails:
4921+
"""Make model selections by computing the model selection score.
49294922
4930-
"""
4931-
give this a tuple of all the Alphafolds and a batch of atomic inputs
4932-
it will select the best one by the model selection score by returning the index of the Tuple
4923+
NOTE: Give this function a tuple of `Alphafold3` modules and a batch of atomic inputs, and it will
4924+
select the best module via the model selection score by returning the index of the corresponding tuple.
4925+
4926+
:param alphafolds: Tuple of `Alphafold3` modules
4927+
:param batched_atom_inputs: A batch of `AtomInput` data
4928+
:param kwargs: Additional keyword arguments
4929+
:return: Model selection score
49334930
"""
49344931

49354932
samples = []
@@ -4940,19 +4937,15 @@ def forward(
49404937

49414938
pred_atom_pos, logits = alphafold(
49424939
**batched_atom_inputs.model_forward_dict(),
4943-
return_loss = False,
4944-
return_confidence_head_logits = True,
4945-
return_distogram_head_logits = True
4940+
return_loss=False,
4941+
return_confidence_head_logits=True,
4942+
return_distogram_head_logits=True,
49464943
)
4944+
plddt = self.compute_confidence_score.compute_plddt(logits.plddt)
49474945

4948-
samples.append((pred_atom_pos, logits.pde, logits.distance))
4949-
4946+
samples.append((pred_atom_pos, logits.pde, plddt, logits.distance))
49504947

4951-
scores = self.compute_model_selection_score(
4952-
batched_atom_inputs,
4953-
samples = samples,
4954-
**kwargs
4955-
)
4948+
scores = self.compute_model_selection_score(batched_atom_inputs, samples=samples, **kwargs)
49564949

49574950
return scores
49584951

@@ -6083,11 +6076,11 @@ def forward(
60836076
is_nucleotide = is_rna | is_dna
60846077
is_polymer = is_protein | is_rna | is_dna
60856078

6086-
is_any_nucleotide_pair = einx.logical_and(
6087-
'... i, ... j -> ... i j', torch.ones_like(is_nucleotide), is_nucleotide
6079+
is_any_nucleotide_pair = repeat(
6080+
is_nucleotide, '... j -> ... i j', i=is_nucleotide.shape[-1]
60886081
)
6089-
is_any_polymer_pair = einx.logical_and(
6090-
'... i, ... j -> ... i j', torch.ones_like(is_polymer), is_polymer
6082+
is_any_polymer_pair = repeat(
6083+
is_polymer, '... j -> ... i j', i=is_polymer.shape[-1]
60916084
)
60926085

60936086
inclusion_radius = torch.where(
@@ -6098,10 +6091,8 @@ def forward(
60986091

60996092
is_token_center_atom = torch.zeros_like(atom_pos[..., 0], dtype=torch.bool)
61006093
is_token_center_atom[torch.arange(batch_size).unsqueeze(1), molecule_atom_indices] = True
6101-
is_any_token_center_atom_pair = einx.logical_and(
6102-
'... i, ... j -> ... i j',
6103-
torch.ones_like(is_token_center_atom),
6104-
is_token_center_atom,
6094+
is_any_token_center_atom_pair = repeat(
6095+
is_token_center_atom, '... j -> ... i j', i=is_token_center_atom.shape[-1]
61056096
)
61066097

61076098
# compute masks, avoiding self term

alphafold3_pytorch/data/mmcif_writing.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22

33
import numpy as np
44

5-
from alphafold3_pytorch.common.biomolecule import (
6-
_from_mmcif_object,
7-
to_mmcif,
8-
)
5+
from alphafold3_pytorch.common.biomolecule import _from_mmcif_object, to_mmcif
96
from alphafold3_pytorch.data.data_pipeline import get_assembly
107
from alphafold3_pytorch.data.mmcif_parsing import MmcifObject, parse_mmcif_object
118
from alphafold3_pytorch.utils.utils import exists
@@ -27,8 +24,10 @@ def write_mmcif(
2724
insert_orig_atom_names: bool = True,
2825
insert_alphafold_mmcif_metadata: bool = True,
2926
sampled_atom_positions: np.ndarray | None = None,
27+
b_factors: np.ndarray | None = None,
3028
):
31-
"""Write a BioPython `Structure` object to an mmCIF file using an intermediate `Biomolecule` object."""
29+
"""Write a BioPython `Structure` object to an mmCIF file using an intermediate `Biomolecule`
30+
object."""
3231
biomol = (
3332
_from_mmcif_object(mmcif_object)
3433
if "assembly" in mmcif_object.file_id
@@ -41,6 +40,12 @@ def write_mmcif(
4140
f"but got {sampled_atom_positions.shape}."
4241
)
4342
biomol.atom_positions[atom_mask] = sampled_atom_positions
43+
if exists(b_factors):
44+
assert biomol.b_factors[atom_mask].shape == b_factors.shape, (
45+
f"Expected B-factors to have shape {biomol.b_factors[atom_mask].shape}, "
46+
f"but got {b_factors.shape}."
47+
)
48+
biomol.b_factors[atom_mask] = b_factors
4449
unique_res_atom_names = biomol.unique_res_atom_names if insert_orig_atom_names else None
4550
mmcif_string = to_mmcif(
4651
biomol,

0 commit comments

Comments
 (0)