Skip to content

Commit cbd32f3

Browse files
committed
use einx get_at for making the atom attention biasing more clear
1 parent dece70b commit cbd32f3

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

alphafold3_pytorch/attention.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -266,14 +266,14 @@ def local_attn(
266266
attn_bias[..., 2:, :, :]
267267
), dim = -1)
268268

269-
attn_bias, ps = pack_one(attn_bias, '* i j w1 w2')
269+
# get the diagonal
270270

271-
merged_batch = attn_bias.shape[0]
272-
diag_mask = torch.eye(attn_bias.shape[1], device = device, dtype = torch.bool)
273-
diag_mask = repeat(diag_mask, 'i j -> b i j', b = merged_batch)
271+
n = torch.arange(attn_bias.shape[-3], device = device)
274272

275-
attn_bias = rearrange(attn_bias[diag_mask], '(b n) i j -> b n i j', b = merged_batch)
276-
attn_bias = unpack_one(attn_bias, ps, '* n i j')
273+
attn_bias = einx.get_at(
274+
'... [i j] w1 w2, n, n -> ... n w1 w2',
275+
attn_bias, n, n
276+
)
277277

278278
# carry out attention as usual
279279

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.19"
3+
version = "0.0.20"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments (0)