@@ -157,8 +157,6 @@
 
 # constants
 
-SCORED_SAMPLE = Tuple[int, Float["b m 3"], Float[" b"], Float[" b"]] # type: ignore
-
 LinearNoBias = partial(Linear, bias = False)
 
 # always use non reentrant checkpointing
@@ -3470,6 +3468,13 @@ class ConfidenceHeadLogits(NamedTuple):
     plddt: Float['b plddt m']
     resolved: Float['b 2 m']
 
+class Alphafold3Logits(NamedTuple):
+    pae: Float['b pae n n'] | None
+    pde: Float['b pde n n']
+    plddt: Float['b plddt m']
+    resolved: Float['b 2 m']
+    distance: Float['b dist m m'] | Float['b dist n n'] | None
+
 class ConfidenceHead(Module):
     """ Algorithm 31 """
 
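
`Alphafold3Logits` mirrors `ConfidenceHeadLogits` field-for-field and appends an optional `distance` entry, so downstream code can consume either tuple type uniformly and branch on the presence of distogram logits. A minimal standalone sketch of how such an optional-field `NamedTuple` behaves — the `Logits` class, field names, and tensor shapes below are hypothetical stand-ins, not the real head outputs:

```python
from typing import NamedTuple, Optional

import torch
from torch import Tensor

class Logits(NamedTuple):
    pde: Tensor                         # (b, pde_bins, n, n)
    plddt: Tensor                       # (b, plddt_bins, m)
    distance: Optional[Tensor] = None   # (b, dist_bins, n, n) when present

logits = Logits(pde = torch.randn(2, 64, 16, 16), plddt = torch.randn(2, 50, 48))

# the distogram entry may be absent - consumers must branch on None
if logits.distance is None:
    print('no distogram logits attached')
```
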
@@ -4281,6 +4286,13 @@ def _protein_structure_from_feature(
 
     return builder.get_structure()
 
+ScoredSample = Tuple[int, Float["b m 3"], Float[" b"], Float[" b"]] # type: ignore
+
+class ScoreDetails(NamedTuple):
+    best_gpde_index: int
+    best_lddt_index: int
+    score: Float[' b']
+    scored_samples: List[ScoredSample]
 
 class ComputeModelSelectionScore(Module):
     """Compute model selection score."""
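
Each `ScoredSample` packs the sample index, predicted atom positions, and the per-batch weighted lDDT and gPDE into a plain tuple, which `ScoreDetails` then aggregates (holding the full list, hence the `List[ScoredSample]` annotation). A sketch of the intended layout with random tensors — the batch size, atom count, and number of samples are made up:

```python
from typing import List, NamedTuple, Tuple

import torch
from torch import Tensor

# (sample index, atom positions (b, m, 3), weighted lDDT (b,), gPDE (b,))
ScoredSample = Tuple[int, Tensor, Tensor, Tensor]

class ScoreDetails(NamedTuple):
    best_gpde_index: int
    best_lddt_index: int
    score: Tensor
    scored_samples: List[ScoredSample]

b, m = 2, 16
scored: List[ScoredSample] = [
    (i, torch.randn(b, m, 3), torch.rand(b), torch.rand(b)) for i in range(5)
]

details = ScoreDetails(
    best_gpde_index = 0,
    best_lddt_index = 1,
    score = torch.rand(b),
    scored_samples = scored
)

assert details.scored_samples[details.best_lddt_index][0] == 1  # sample index round-trips
```
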
@@ -4787,15 +4799,20 @@ def compute_unresolved_rasa(
     def compute_model_selection_score(
         self,
         batch: BatchedAtomInput,
-        samples: List[Tuple[Float["b m 3"], Float["b pde n n"], Float["b dist n n"]]],
+        samples: List[Tuple[
+            Float["b m 3"],
+            Float["b pde n n"],
+            Float["b dist n n"]
+        ]],
         is_fine_tuning: bool = None,
-        return_top_model: bool = False,
+        return_details: bool = False,
         return_unweighted_scores: bool = False,
         compute_rasa: bool = False,
         unresolved_cid: List[int] | None = None,
         unresolved_residue_mask: Bool["b n"] | None = None,
         missing_chain_index: int = -1,
-    ) -> Float[" b"] | Tuple[Float[" b"], SCORED_SAMPLE]:
+    ) -> Float[" b"] | ScoreDetails:
+
        """Compute the model selection score for an input batch and corresponding (sampled) atom
        positions.
 
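
Per diffusion sample, the reshaped `samples` argument expects the predicted atom coordinates plus the PDE and distogram logits. A toy illustration of assembling such a list with random tensors — all sizes are placeholders; the real entries come from the model's confidence and distogram heads:

```python
import torch

b, m, n = 1, 32, 8            # batch, atoms, residue tokens (placeholder sizes)
pde_bins, dist_bins = 64, 38  # placeholder bin counts

samples = [
    (
        torch.randn(b, m, 3),             # Float["b m 3"]     - sampled atom positions
        torch.randn(b, pde_bins, n, n),   # Float["b pde n n"] - PDE logits
        torch.randn(b, dist_bins, n, n),  # Float["b dist n n"] - distogram logits
    )
    for _ in range(5)
]
```
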
@@ -4841,7 +4858,7 @@ def compute_model_selection_score(
 
         # score samples
 
-        scored_samples: List[SCORED_SAMPLE] = []
+        scored_samples: List[ScoredSample] = []
 
         for sample_idx, sample in enumerate(samples):
             atom_pos_pred, pde_logits, dist_logits = sample
@@ -4871,19 +4888,73 @@ def compute_model_selection_score(
 
             scored_samples.append((sample_idx, atom_pos_pred, weighted_lddt, gpde))
 
-        top_ranked_sample = max(
-            scored_samples, key=lambda x: x[-1].mean()
-        )  # rank by batch-averaged gPDE
-        best_of_5_sample = max(
-            scored_samples, key=lambda x: x[-2].mean()
-        )  # rank by batch-averaged lDDT
+        # quick collate
+
+        *_, all_weighted_lddt, all_gpde = zip(*scored_samples)
+
+        # rank by batch-averaged gPDE
 
-        model_selection_score = (top_ranked_sample[-2] + best_of_5_sample[-2]) / 2
+        best_gpde_index = torch.stack(all_gpde).mean(dim = -1).topk(1).indices.item()
 
-        if return_top_model:
-            return model_selection_score, top_ranked_sample
+        # rank by batch-averaged lDDT
 
-        return model_selection_score
+        best_lddt_index = torch.stack(all_weighted_lddt).mean(dim = -1).topk(1).indices.item()
+
+        # model selection score - average the weighted lDDT of the best-gPDE and best-lDDT samples
+
+        model_selection_score = (
+            scored_samples[best_gpde_index][-2] +
+            scored_samples[best_lddt_index][-2]
+        ) / 2
+
+        if not return_details:
+            return model_selection_score
+
+        score_details = ScoreDetails(
+            best_gpde_index = best_gpde_index,
+            best_lddt_index = best_lddt_index,
+            score = model_selection_score,
+            scored_samples = scored_samples
+        )
+
+        return score_details
+
+    @typecheck
+    def forward(
+        self,
+        alphafolds: Tuple[Alphafold3, ...],
+        batched_atom_inputs: BatchedAtomInput,
+        **kwargs
+    ) -> Float[" b"] | ScoreDetails:
+
+        """
+        given a tuple of Alphafold3 models and a batch of atomic inputs, samples a structure from each
+        and returns the model selection score (or ScoreDetails with the best sample indices, if return_details = True)
+        """
+
+        samples = []
+
+        with torch.no_grad():
+            for alphafold in alphafolds:
+                alphafold.eval()
+
+                pred_atom_pos, logits = alphafold(
+                    **batched_atom_inputs.model_forward_dict(),
+                    return_loss = False,
+                    return_confidence_head_logits = True,
+                    return_distogram_head_logits = True
+                )
+
+                samples.append((pred_atom_pos, logits.pde, logits.distance))
+
+
+        scores = self.compute_model_selection_score(
+            batched_atom_inputs,
+            samples = samples,
+            **kwargs
+        )
+
+        return scores
 
 # main class
 
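
The new selection logic collates the per-sample `(index, positions, weighted_lddt, gpde)` tuples with `zip(*...)`, ranks the stacked metrics by their batch means via `topk(1)`, and averages the weighted lDDT of the two winners. A self-contained rerun of that arithmetic on random tensors (batch size and sample count are arbitrary):

```python
import torch

b = 2
scored_samples = [
    (i, torch.randn(b, 16, 3), torch.rand(b), torch.rand(b)) for i in range(5)
]

# collate: discard the leading fields, keep the two metric columns
*_, all_weighted_lddt, all_gpde = zip(*scored_samples)

# each stack is (num_samples, b); mean over batch, then take the top index
best_gpde_index = torch.stack(all_gpde).mean(dim = -1).topk(1).indices.item()
best_lddt_index = torch.stack(all_weighted_lddt).mean(dim = -1).topk(1).indices.item()

# final score: average of the weighted lDDT ((b,) tensors at position -2) of the two winners
score = (scored_samples[best_gpde_index][-2] + scored_samples[best_lddt_index][-2]) / 2
assert score.shape == (b,)
```
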
@@ -5383,8 +5454,7 @@ def forward(
         hard_validate: bool = False
     ) -> (
         Float['b m 3'] |
-        Tuple[Float['b m 3'], ConfidenceHeadLogits] |
-        Tuple[Float['b m 3'], ConfidenceHeadLogits, Float['b l n n'] | Float['b l m m']] |
+        Tuple[Float['b m 3'], ConfidenceHeadLogits | Alphafold3Logits] |
         Float[''] |
         Tuple[Float[''], LossBreakdown]
     ):
 
5673 | 5743 | return_pae_logits = True |
5674 | 5744 | ) |
5675 | 5745 |
|
5676 | | - if not return_distogram_head_logits: |
5677 | | - return sampled_atom_pos, confidence_head_logits |
| 5746 | + returned_logits = confidence_head_logits |
5678 | 5747 |
|
5679 | | - distogram_head_logits = self.distogram_head(pairwise.clone().detach()) |
| 5748 | + if return_distogram_head_logits: |
| 5749 | + distogram_head_logits = self.distogram_head(pairwise.clone().detach()) |
5680 | 5750 |
|
5681 | | - return ( |
5682 | | - sampled_atom_pos, |
5683 | | - confidence_head_logits, |
5684 | | - distogram_head_logits, |
5685 | | - ) |
| 5751 | + returned_logits = Alphafold3Logits( |
| 5752 | + **confidence_head_logits._asdict(), |
| 5753 | + distance = distogram_head_logits |
| 5754 | + ) |
| 5755 | + |
| 5756 | + return sampled_atom_pos, returned_logits |
5686 | 5757 |
|
5687 | 5758 | # if being forced to return loss, but do not have sufficient information to return losses, just return 0 |
5688 | 5759 |
|
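
The refactored tail of `forward` starts from the confidence head logits and only upgrades them to `Alphafold3Logits` when distogram logits were requested, reusing every existing field through `NamedTuple._asdict()`. A standalone sketch of that conditional-enrichment pattern — the two tuple classes, the `finish` helper, and the tensors are dummies for illustration:

```python
from typing import NamedTuple, Optional

import torch
from torch import Tensor

class BaseLogits(NamedTuple):
    plddt: Tensor

class FullLogits(NamedTuple):
    plddt: Tensor
    distance: Optional[Tensor]

def finish(base: BaseLogits, want_distogram: bool):
    returned = base
    if want_distogram:
        # _asdict() copies every existing field; only the new one is appended
        returned = FullLogits(**base._asdict(), distance = torch.randn(1, 38, 8, 8))
    return returned

# returns FullLogits here, BaseLogits when want_distogram = False
print(type(finish(BaseLogits(plddt = torch.randn(1, 50, 16)), want_distogram = True)).__name__)
```
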