@@ -3683,7 +3683,8 @@ def compute_ptm(
36833683 molecule_atom_lens : Int ["b n" ] | None = None ,
36843684 interface : bool = False ,
36853685 compute_chain_wise_iptm : bool = False ,
3686- ):
3686+ ) -> Float [" b" ] | Tuple [Float ["b chains chains" ], Bool ["b chains chains" ], Int ["b chains" ]]:
3687+
36873688 """Compute pTM from logits.
36883689
36893690 :param logits: [b c n n] logits
@@ -3759,22 +3760,22 @@ def compute_ptm(
37593760 if pair_residue_weights .sum () == 0 :
37603761 # chain i or chain j does not have any valid frame
37613762 continue
3762- else :
3763- normed_residue_mask = pair_residue_weights / (
3764- self .eps
3765- + torch .sum (pair_residue_weights , dim = - 1 , keepdims = True )
3766- )
37673763
3768- masked_predicted_tm_term = predicted_tm_term [b ] * pair_mask
3764+ normed_residue_mask = pair_residue_weights / (
3765+ self .eps
3766+ + torch .sum (pair_residue_weights , dim = - 1 , keepdims = True )
3767+ )
3768+
3769+ masked_predicted_tm_term = predicted_tm_term [b ] * pair_mask
37693770
3770- per_alignment = torch .sum (
3771- masked_predicted_tm_term * normed_residue_mask , dim = - 1
3772- )
3773- weighted_argmax = (residue_weights [b ] * per_alignment ).argmax ()
3774- chain_wise_iptm [b , i , j ] = per_alignment [weighted_argmax ]
3775- chain_wise_iptm_mask [b , i , j ] = True
3771+ per_alignment = torch .sum (
3772+ masked_predicted_tm_term * normed_residue_mask , dim = - 1
3773+ )
3774+ weighted_argmax = (residue_weights [b ] * per_alignment ).argmax ()
3775+ chain_wise_iptm [b , i , j ] = per_alignment [weighted_argmax ]
3776+ chain_wise_iptm_mask [b , i , j ] = True
37763777
3777- return chain_wise_iptm , chain_wise_iptm_mask , unique_chains
3778+ return chain_wise_iptm , chain_wise_iptm_mask , torch . tensor ( unique_chains )
37783779
37793780 else :
37803781 pair_mask = torch .ones (size = (num_batch , num_res , num_res ), device = logits .device ).bool ()
@@ -3974,7 +3975,8 @@ def compute_full_complex_metric(
39743975 atom_mask : Bool ["b m" ],
39753976 is_molecule_types : Bool [f"b n { IS_MOLECULE_TYPES } " ],
39763977 return_confidence_score : bool = False ,
3977- ) -> Float [" b" ] | Tuple [Float [" b" ], Tuple [ConfidenceScore , Bool [" b" ]]]:
3978+ ) -> Float [" b" ] | Tuple [Float [" b" ], Tuple [ConfidenceScore , Bool [" b" ]]]:
3979+
39783980 """Compute full complex metric.
39793981
39803982 :param confidence_head_logits: ConfidenceHeadLogits
@@ -4105,7 +4107,7 @@ def compute_interface_metric(
41054107
41064108 for b , chains in enumerate (interface_chains ):
41074109 for chain in chains :
4108- idx = unique_chains [b ].index (chain )
4110+ idx = unique_chains [b ].tolist (). index (chain )
41094111 interface_metric [b ] += iptm_sum [b , idx ].sum () / iptm_count [b , idx ].sum ().clamp (
41104112 min = 1
41114113 )
0 commit comments