Skip to content

Commit b0324e5

Browse files
committed
fix msa_mask gathering when subsampling
1 parent d7a3e1b commit b0324e5

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,13 +1142,12 @@ def forward(
11421142
# msa = einx.get_at('b [s] n dm, b sampled -> b sampled n dm', msa, indices)
11431143

11441144
msa, unpack_one = pack_one(msa, 'b s *')
1145-
indices = repeat(indices, 'b sampled -> b sampled d', d = msa.shape[-1])
1146-
msa = msa.gather(1, indices)
1145+
msa_indices = repeat(indices, 'b sampled -> b sampled d', d = msa.shape[-1])
1146+
msa = msa.gather(1, msa_indices)
11471147
msa = unpack_one(msa)
11481148

11491149
if exists(msa_mask):
11501150
# msa_mask = einx.get_at('b [s], b sampled -> b sampled', msa_mask, indices)
1151-
11521151
msa_mask = msa_mask.gather(1, indices)
11531152

11541153
# account for no msa

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.105"
3+
version = "0.2.106"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

tests/test_af3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def test_msa_module(
228228
pairwise = torch.randn(2, 16, 16, 128).requires_grad_()
229229
msa = torch.randn(2, 7, 16, 64)
230230
mask = torch.randint(0, 2, (2, 16)).bool()
231+
msa_mask = torch.randint(0, 2, (2, 7)).bool()
231232

232233
msa_module = MSAModule(
233234
checkpoint = checkpoint,
@@ -238,7 +239,8 @@ def test_msa_module(
238239
msa = msa,
239240
single_repr = single,
240241
pairwise_repr = pairwise,
241-
mask = mask
242+
mask = mask,
243+
msa_mask = msa_mask
242244
)
243245

244246
assert pairwise.shape == pairwise_out.shape

0 commit comments

Comments
 (0)