Skip to content

Commit 22447b5

Browse files
committed
tiny cleanup
1 parent 9aac7c2 commit 22447b5

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,19 @@ def unpack_one(to_unpack, unpack_pattern = None):
170170
def exclusive_cumsum(t, dim = -1):
171171
return t.cumsum(dim = dim) - t
172172

173+
@typecheck
174+
def masked_average(
175+
t: Shaped['...'],
176+
mask: Shaped['...'],
177+
*,
178+
dim: int | Tuple[int, ...],
179+
eps = 1.
180+
) -> Float['...']:
181+
182+
num = (t * mask).sum(dim = dim)
183+
den = mask.sum(dim = dim)
184+
return num / den.clamp(min = eps)
185+
173186
# checkpointing utils
174187

175188
@typecheck
@@ -3565,11 +3578,10 @@ def compute_ptm(
35653578
mask_j = (asym_id[b] == chain_j)
35663579
pair_mask = einx.multiply('i, j -> i j', mask_i, mask_j)
35673580

3568-
pair_residue_weights = pair_mask * (
3569-
residue_weights[b, None, :] * residue_weights[b, :, None])
3581+
pair_residue_weights = pair_mask * einx.multiply('... i, ... j -> ... i j', residue_weights[b], residue_weights[b])
35703582

35713583
if pair_residue_weights.sum() == 0:
3572-
# chain i or chain j doesnot have any valid frame
3584+
# chain i or chain j does not have any valid frame
35733585
continue
35743586

35753587
normed_residue_mask = pair_residue_weights / (self.eps + torch.sum(
@@ -3738,7 +3750,7 @@ def compute_disorder(
37383750
is_protein_mask = atom_is_molecule_types[..., IS_PROTEIN_INDEX]
37393751
mask = atom_mask * is_protein_mask
37403752

3741-
atom_rasa = 1 - plddt
3753+
atom_rasa = 1. - plddt
37423754

37433755
disorder = ( (atom_rasa > 0.581) * mask ).sum(dim=-1) / ( self.eps + mask.sum(dim=1))
37443756
return disorder
@@ -3864,7 +3876,7 @@ def compute_modified_residue_score(
38643876
plddt = self.compute_confidence_score.compute_plddt(confidence_head_logits.plddt)
38653877

38663878
mask = atom_is_modified_residue * atom_mask
3867-
plddt_mean = (plddt * mask).sum(dim=-1) / ( self.eps + mask.sum(dim=-1))
3879+
plddt_mean = masked_average(plddt, mask, dim = -1, eps = self.eps)
38683880

38693881
return plddt_mean
38703882

@@ -3991,7 +4003,7 @@ def compute_gpde(
39914003
contact_prob = contact_prob * mask
39924004

39934005
# Section 5.7 equation 16
3994-
gpde = einsum(contact_prob * pde, 'b i j -> b') / einsum(contact_prob, 'b i j -> b').clamp(min=1.)
4006+
gpde = masked_average(pde, contact_prob, dim = (-1, -2))
39954007

39964008
return gpde
39974009

@@ -4050,9 +4062,7 @@ def compute_lddt(
40504062
mask = mask * pairwise_mask
40514063

40524064
# Calculate masked averaging
4053-
lddt_sum = (lddt * mask).sum(dim=(-1, -2))
4054-
lddt_count = mask.sum(dim=(-1, -2))
4055-
lddt_mean = lddt_sum / lddt_count.clamp(min=1)
4065+
lddt_mean = masked_average(lddt, mask, dim = (-1, -2))
40564066

40574067
return lddt_mean
40584068

0 commit comments

Comments
 (0)