Skip to content

Commit b93e278

Browse files
committed
Complete the average-pooling-by-atom-lengths function needed for the packed representation of atoms in the diffusion module, when going from atoms -> tokens
1 parent 54c4cf7 commit b93e278

File tree

2 files changed

+30
-40
lines changed

2 files changed

+30
-40
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -87,35 +87,33 @@ def inner(t, *args, **kwargs):
8787
return fn(t, *args, **kwargs)
8888
return inner
8989

90-
# Loss functions
90+
# packed atom representation functions
9191

9292
@typecheck
93-
def calc_smooth_lddt_loss(
94-
denoised: Float['b m 3'],
95-
ground_truth: Float['b m 3'],
96-
is_rna_per_atom: Float['b m'],
97-
is_dna_per_atom: Float['b m']
98-
) -> Float[' b']:
99-
100-
m, device = is_rna_per_atom.shape[-1], denoised.device
101-
102-
dx_denoised = torch.cdist(denoised, denoised)
103-
dx_gt = torch.cdist(ground_truth, ground_truth)
104-
105-
ddx = torch.abs(dx_gt - dx_denoised)
106-
eps = 0.25 * (
107-
sigmoid(0.5 - ddx) + sigmoid(1 - ddx) + sigmoid(2 - ddx) + sigmoid(4 - ddx)
108-
)
109-
110-
is_nuc = is_rna_per_atom + is_dna_per_atom
111-
mask = einx.multiply('b i, b j -> b i j', is_nuc, is_nuc)
112-
c = (dx_gt < 30) * mask + (dx_gt < 15) * (1 - mask)
113-
114-
eye = torch.eye(m, device = device)
115-
num = einx.sum('b [...]', c * eps * (1 - eye)) / (m**2 - m)
116-
den = einx.sum('b [...]', c * (1 - eye)) / (m**2 - m)
93+
def mean_pool_with_lens(
    feats: Float['b m d'],
    lens: Int['b n']
) -> Float['b n d']:
    """
    Average-pool contiguous runs of packed atom features into token features.

    Token ``i`` of each batch row owns the next ``lens[b, i]`` atoms of
    ``feats``, laid out consecutively along the sequence axis. Tokens whose
    length is zero receive an all-zero output vector.

    :param feats: packed per-atom features, shape (batch, total_atoms, dim)
    :param lens:  number of atoms owned by each token, shape (batch, tokens)
    :return:      per-token mean of the owned atom features, shape (batch, tokens, dim)
    """

    total_len = feats.shape[1]

    # remember which tokens own at least one atom - empty ones are zeroed at the end
    nonempty = lens > 0

    assert (lens.sum(dim = -1) <= total_len).all(), 'one of the lengths given exceeds the total sequence length of the features passed in'

    # prefix sums along the atom axis, with a leading zero row so that
    # the sum of atoms in [a, b) is prefix[b] - prefix[a]
    prefix = F.pad(feats.cumsum(dim = 1), (0, 0, 1, 0), value = 0.)

    # segment boundaries: index just past the last atom of each token,
    # prepended with the zero boundary
    bounds = F.pad(lens.cumsum(dim = 1), (1, 0), value = 0).long()

    # pick out the prefix sums sitting at every boundary
    gathered = prefix.gather(1, bounds.unsqueeze(-1).expand(-1, -1, feats.shape[-1]))

    # consecutive boundary differences give the per-token sums
    seg_sum = gathered[:, 1:] - gathered[:, :-1]

    # divide by segment length (clamped so empty segments avoid 0 / 0),
    # then force empty tokens to exact zeros
    seg_mean = seg_sum / lens.clamp(min = 1).unsqueeze(-1)
    return seg_mean.masked_fill(~nonempty.unsqueeze(-1), 0.)
119117

120118
# linear and outer sum
121119
# for single repr -> pairwise pattern throughout this architecture

tests/test_af3.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,15 @@
2727
)
2828

2929
from alphafold3_pytorch.alphafold3 import (
30-
calc_smooth_lddt_loss
30+
mean_pool_with_lens
3131
)
3232

33-
def test_calc_smooth_lddt_loss():
34-
denoised = torch.randn(8, 100, 3)
35-
ground_truth = torch.randn(8, 100, 3)
36-
is_rna_per_atom = torch.randint(0, 2, (8, 100)).float()
37-
is_dna_per_atom = torch.randint(0, 2, (8, 100)).float()
38-
39-
loss = calc_smooth_lddt_loss(
40-
denoised,
41-
ground_truth,
42-
is_rna_per_atom,
43-
is_dna_per_atom
44-
)
33+
def test_mean_pool_with_lens():
    # three tokens owning 3, 4 and 2 atoms respectively; atoms within a
    # token carry a constant feature, so each token's mean is that constant
    atom_feats = torch.cat((
        torch.full((1, 3, 1), 1.),
        torch.full((1, 4, 1), 2.),
        torch.full((1, 2, 1), 1.)
    ), dim = 1)

    atom_lens = torch.tensor([[3, 4, 2]], dtype = torch.long)

    token_feats = mean_pool_with_lens(atom_feats, atom_lens)

    expected = torch.tensor([[[1.], [2.], [1.]]])
    assert torch.allclose(token_feats, expected)
4739

4840
def test_smooth_lddt_loss():
4941
pred_coords = torch.randn(2, 100, 3)

0 commit comments

Comments
 (0)