@@ -4286,7 +4286,8 @@ def _protein_structure_from_feature(
42864286
42874287 return builder .get_structure ()
42884288
4289- ScoredSample = Tuple [int , Float ["b m 3" ], Float [" b" ], Float [" b" ]] # type: ignore
4289+ Sample = Tuple [Float ["b m 3" ], Float ["b pde n n" ], Float ["b m" ], Float ["b dist n n" ]]
4290+ ScoredSample = Tuple [int , Float ["b m 3" ], Float ["b m" ], Float [" b" ], Float [" b" ]]
42904291
42914292class ScoreDetails (NamedTuple ):
42924293 best_gpde_index : int
@@ -4799,27 +4800,22 @@ def compute_unresolved_rasa(
47994800 def compute_model_selection_score (
48004801 self ,
48014802 batch : BatchedAtomInput ,
4802- samples : List [Tuple [
4803- Float ["b m 3" ],
4804- Float ["b pde n n" ],
4805- Float ["b dist n n" ]
4806- ]],
4803+ samples : List [Sample ],
48074804 is_fine_tuning : bool = None ,
48084805 return_details : bool = False ,
48094806 return_unweighted_scores : bool = False ,
48104807 compute_rasa : bool = False ,
48114808 unresolved_cid : List [int ] | None = None ,
4812- unresolved_residue_mask : Bool ["b n" ] | None = None ,
4809+ unresolved_residue_mask : Bool ["b n" ] | None = None ,
48134810 missing_chain_index : int = - 1 ,
48144811 ) -> Float [" b" ] | ScoreDetails :
4815-
48164812 """Compute the model selection score for an input batch and corresponding (sampled) atom
48174813 positions.
48184814
48194815 :param batch: A batch of `AtomInput` data.
48204816 :param samples: A list of sampled atom positions along with their predicted distance errors and labels.
48214817 :param is_fine_tuning: is fine tuning
4822- :param return_top_model : return the top-ranked sample
4818+ :param return_details : return the top model and its score
48234819 :param return_unweighted_scores: return the unweighted scores (i.e., lDDT)
48244820 :param compute_rasa: compute the relative solvent accessible surface area (RASA) for unresolved proteins
48254821 :param unresolved_cid: unresolved chain ids
@@ -4861,7 +4857,7 @@ def compute_model_selection_score(
48614857 scored_samples : List [ScoredSample ] = []
48624858
48634859 for sample_idx , sample in enumerate (samples ):
4864- atom_pos_pred , pde_logits , dist_logits = sample
4860+ atom_pos_pred , pde_logits , plddt , dist_logits = sample
48654861
48664862 weighted_lddt = self .compute_weighted_lddt (
48674863 atom_pos_pred ,
@@ -4886,50 +4882,51 @@ def compute_model_selection_score(
48864882 tok_repr_atm_mask ,
48874883 )
48884884
4889- scored_samples .append ((sample_idx , atom_pos_pred , weighted_lddt , gpde ))
4885+ scored_samples .append ((sample_idx , atom_pos_pred , plddt , weighted_lddt , gpde ))
48904886
48914887 # quick collate
48924888
48934889 * _ , all_weighted_lddt , all_gpde = zip (* scored_samples )
48944890
4895- # rank by batch-averaged gPDE
4891+ # rank by batch-averaged minimum gPDE
48964892
4897- best_gpde_index = torch .stack (all_gpde ).mean (dim = - 1 ).argmax ().item ()
4893+ best_gpde_index = torch .stack (all_gpde ).mean (dim = - 1 ).argmin ().item ()
48984894
4899- # rank by batch-averaged lDDT
4895+ # rank by batch-averaged maximum lDDT
49004896
4901- best_lddt_index = torch .stack (all_weighted_lddt ).mean (dim = - 1 ).argmax ().item ()
4897+ best_lddt_index = torch .stack (all_weighted_lddt ).mean (dim = - 1 ).argmax ().item ()
49024898
49034899 # some weighted score
49044900
49054901 model_selection_score = (
4906- scored_samples [best_gpde_index ][- 2 ] +
4907- scored_samples [best_lddt_index ][- 2 ]
4902+ scored_samples [best_gpde_index ][- 2 ] + scored_samples [best_lddt_index ][- 2 ]
49084903 ) / 2
49094904
49104905 if not return_details :
49114906 return model_selection_score
49124907
49134908 score_details = ScoreDetails (
4914- best_gpde_index = best_gpde_index ,
4915- best_lddt_index = best_lddt_index ,
4916- score = model_selection_score ,
4917- scored_samples = scored_samples
4909+ best_gpde_index = best_gpde_index ,
4910+ best_lddt_index = best_lddt_index ,
4911+ score = model_selection_score ,
4912+ scored_samples = scored_samples ,
49184913 )
49194914
49204915 return score_details
49214916
49224917 @typecheck
49234918 def forward (
4924- self ,
4925- alphafolds : Tuple [Alphafold3 ],
4926- batched_atom_inputs : BatchedAtomInput ,
4927- ** kwargs
4919+ self , alphafolds : Tuple [Alphafold3 ], batched_atom_inputs : BatchedAtomInput , ** kwargs
49284920 ) -> Float [" b" ] | ScoreDetails :
4921+ """Make model selections by computing the model selection score.
49294922
4930- """
4931- give this a tuple of all the Alphafolds and a batch of atomic inputs
4932- it will select the best one by the model selection score by returning the index of the Tuple
4923+ NOTE: Give this function a tuple of `Alphafold3` modules and a batch of atomic inputs, and it will
4924+ select the best module via the model selection score by returning the index of the corresponding tuple.
4925+
4926+ :param alphafolds: Tuple of `Alphafold3` modules
4927+ :param batched_atom_inputs: A batch of `AtomInput` data
4928+ :param kwargs: Additional keyword arguments
4929+ :return: Model selection score
49334930 """
49344931
49354932 samples = []
@@ -4940,19 +4937,15 @@ def forward(
49404937
49414938 pred_atom_pos , logits = alphafold (
49424939 ** batched_atom_inputs .model_forward_dict (),
4943- return_loss = False ,
4944- return_confidence_head_logits = True ,
4945- return_distogram_head_logits = True
4940+ return_loss = False ,
4941+ return_confidence_head_logits = True ,
4942+ return_distogram_head_logits = True ,
49464943 )
4944+ plddt = self .compute_confidence_score .compute_plddt (logits .plddt )
49474945
4948- samples .append ((pred_atom_pos , logits .pde , logits .distance ))
4949-
4946+ samples .append ((pred_atom_pos , logits .pde , plddt , logits .distance ))
49504947
4951- scores = self .compute_model_selection_score (
4952- batched_atom_inputs ,
4953- samples = samples ,
4954- ** kwargs
4955- )
4948+ scores = self .compute_model_selection_score (batched_atom_inputs , samples = samples , ** kwargs )
49564949
49574950 return scores
49584951
@@ -6083,11 +6076,11 @@ def forward(
60836076 is_nucleotide = is_rna | is_dna
60846077 is_polymer = is_protein | is_rna | is_dna
60856078
6086- is_any_nucleotide_pair = einx . logical_and (
6087- '... i, ... j -> ... i j' , torch . ones_like ( is_nucleotide ), is_nucleotide
6079+ is_any_nucleotide_pair = repeat (
6080+ is_nucleotide , ' ... j -> ... i j' , i = is_nucleotide . shape [ - 1 ]
60886081 )
6089- is_any_polymer_pair = einx . logical_and (
6090- '... i, ... j -> ... i j' , torch . ones_like ( is_polymer ), is_polymer
6082+ is_any_polymer_pair = repeat (
6083+ is_polymer , ' ... j -> ... i j' , i = is_polymer . shape [ - 1 ]
60916084 )
60926085
60936086 inclusion_radius = torch .where (
@@ -6098,10 +6091,8 @@ def forward(
60986091
60996092 is_token_center_atom = torch .zeros_like (atom_pos [..., 0 ], dtype = torch .bool )
61006093 is_token_center_atom [torch .arange (batch_size ).unsqueeze (1 ), molecule_atom_indices ] = True
6101- is_any_token_center_atom_pair = einx .logical_and (
6102- '... i, ... j -> ... i j' ,
6103- torch .ones_like (is_token_center_atom ),
6104- is_token_center_atom ,
6094+ is_any_token_center_atom_pair = repeat (
6095+ is_token_center_atom , '... j -> ... i j' , i = is_token_center_atom .shape [- 1 ]
61056096 )
61066097
61076098 # compute masks, avoiding self term
0 commit comments