@@ -170,6 +170,19 @@ def unpack_one(to_unpack, unpack_pattern = None):
170170def exclusive_cumsum (t , dim = - 1 ):
171171 return t .cumsum (dim = dim ) - t
172172
173+ @typecheck
174+ def masked_average (
175+ t : Shaped ['...' ],
176+ mask : Shaped ['...' ],
177+ * ,
178+ dim : int | Tuple [int , ...],
179+ eps = 1.
180+ ) -> Float ['...' ]:
181+
182+ num = (t * mask ).sum (dim = dim )
183+ den = mask .sum (dim = dim )
184+ return num / den .clamp (min = eps )
185+
173186# checkpointing utils
174187
175188@typecheck
@@ -3565,11 +3578,10 @@ def compute_ptm(
35653578 mask_j = (asym_id [b ] == chain_j )
35663579 pair_mask = einx .multiply ('i, j -> i j' , mask_i , mask_j )
35673580
3568- pair_residue_weights = pair_mask * (
3569- residue_weights [b , None , :] * residue_weights [b , :, None ])
3581+ pair_residue_weights = pair_mask * einx .multiply ('... i, ... j -> ... i j' , residue_weights [b ], residue_weights [b ])
35703582
35713583 if pair_residue_weights .sum () == 0 :
3572- # chain i or chain j doesnot have any valid frame
3584+ # chain i or chain j does not have any valid frame
35733585 continue
35743586
35753587 normed_residue_mask = pair_residue_weights / (self .eps + torch .sum (
@@ -3738,7 +3750,7 @@ def compute_disorder(
37383750 is_protein_mask = atom_is_molecule_types [..., IS_PROTEIN_INDEX ]
37393751 mask = atom_mask * is_protein_mask
37403752
3741- atom_rasa = 1 - plddt
3753+ atom_rasa = 1. - plddt
37423754
37433755 disorder = ( (atom_rasa > 0.581 ) * mask ).sum (dim = - 1 ) / ( self .eps + mask .sum (dim = 1 ))
37443756 return disorder
@@ -3864,7 +3876,7 @@ def compute_modified_residue_score(
38643876 plddt = self .compute_confidence_score .compute_plddt (confidence_head_logits .plddt )
38653877
38663878 mask = atom_is_modified_residue * atom_mask
3867- plddt_mean = (plddt * mask ). sum ( dim = - 1 ) / ( self .eps + mask . sum ( dim = - 1 ))
3879+ plddt_mean = masked_average (plddt , mask , dim = - 1 , eps = self .eps )
38683880
38693881 return plddt_mean
38703882
@@ -3991,7 +4003,7 @@ def compute_gpde(
39914003 contact_prob = contact_prob * mask
39924004
39934005 # Section 5.7 equation 16
3994- gpde = einsum ( contact_prob * pde , 'b i j -> b' ) / einsum ( contact_prob , 'b i j -> b' ). clamp ( min = 1. )
4006+ gpde = masked_average ( pde , contact_prob , dim = ( - 1 , - 2 ) )
39954007
39964008 return gpde
39974009
@@ -4050,9 +4062,7 @@ def compute_lddt(
40504062 mask = mask * pairwise_mask
40514063
40524064 # Calculate masked averaging
4053- lddt_sum = (lddt * mask ).sum (dim = (- 1 , - 2 ))
4054- lddt_count = mask .sum (dim = (- 1 , - 2 ))
4055- lddt_mean = lddt_sum / lddt_count .clamp (min = 1 )
4065+ lddt_mean = masked_average (lddt , mask , dim = (- 1 , - 2 ))
40564066
40574067 return lddt_mean
40584068
0 commit comments