@@ -3309,13 +3309,20 @@ class ComputeRankingScore(Module):
33093309
33103310 def __init__ (
33113311 self ,
3312- eps = 1e-8
3312+ eps = 1e-8 ,
3313+ score_iptm_weight = 0.8 ,
3314+ score_ptm_weight = 0.2 ,
3315+ score_disorder_weight = 0.5
33133316 ):
33143317 super ().__init__ ()
33153318 self .eps = eps
33163319 self .compute_clash = ComputeClash ()
33173320 self .compute_confidence_score = ComputeConfidenceScore (eps = eps )
33183321
3322+ self .score_iptm_weight = score_iptm_weight
3323+ self .score_ptm_weight = score_ptm_weight
3324+ self .score_disorder_weight = score_disorder_weight
3325+
33193326 @typecheck
33203327 def compute_disorder (
33213328 self ,
@@ -3342,7 +3349,8 @@ def compute_full_complex_metric(
33423349 atom_pos : Float ['b m 3' ],
33433350 atom_mask : Bool ['b m' ],
33443351 is_molecule_types : Bool [f'b n { IS_MOLECULE_TYPES } ' ],
3345- ) -> Float [' b' ]:
3352+ return_confidence_score : bool = False
3353+ ) -> Float [' b' ] | Tuple [Float [' b' ], Tuple [ConfidenceScore , Bool [' b' ]]]:
33463354
33473355 # Section 5.9.3.1
33483356
@@ -3372,9 +3380,18 @@ def compute_full_complex_metric(
33723380 disorder = self .compute_disorder (confidence_score .plddt , atom_mask , atom_is_molecule_types )
33733381
33743382 # Section 5.9.3 equation 19
3375- score = 0.8 * confidence_score .iptm + 0.2 * confidence_score .ptm + 0.5 * disorder - 100 * has_clash
33763383
3377- return score
3384+ weighted_score = (
3385+ confidence_score .iptm * self .score_iptm_weight +
3386+ confidence_score .ptm * self .score_ptm_weight +
3387+ disorder * self .score_disorder_weight
3388+ - 100 * has_clash
3389+ )
3390+
3391+ if not return_confidence_score :
3392+ return weighted_score
3393+
3394+ return weighted_score , (confidence_score , has_clash )
33783395
33793396 @typecheck
33803397 def compute_single_chain_metric (
0 commit comments