Skip to content

Commit 53e1661

Browse files
committed
address a memory issue with OuterProductMean raised by Sergey, taking care of both masked and unmasked versions
1 parent 64fc0b1 commit 53e1661

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -795,25 +795,28 @@ def forward(
795795
msa_mask: Bool['b s'] | None = None
796796
) -> Float['b n n dp']:
797797

798+
num_msa = msa.shape[1]
799+
798800
msa = self.norm(msa)
799801

800802
# line 2
801803

802804
a, b = self.to_hidden(msa).chunk(2, dim = -1)
803805

804-
outer_product = einsum(a, b, 'b s i d, b s j e -> b i j d e s')
805-
806806
# maybe masked mean for outer product
807807

808808
if exists(msa_mask):
809-
outer_product = einx.multiply('b i j d e s, b s -> b i j d e s', outer_product, msa_mask.float())
809+
a = einx.multiply('b s i d, b s -> b s i d', a, msa_mask.float())
810+
b = einx.multiply('b s j e, b s -> b s j e', b, msa_mask.float())
811+
812+
outer_product = einsum(a, b, 'b s i d, b s j e -> b i j d e')
810813

811-
num = reduce(outer_product, '... s -> ...', 'sum')
812-
den = reduce(msa_mask.float(), '... s -> ...', 'sum')
814+
num_msa = reduce(msa_mask.float(), '... s -> ...', 'sum')
813815

814-
outer_product_mean = einx.divide('b i j d e, b', num, den.clamp(min = self.eps))
816+
outer_product_mean = einx.divide('b i j d e, b', outer_product, num_msa.clamp(min = self.eps))
815817
else:
816-
outer_product_mean = reduce(outer_product, '... s -> ...', 'mean')
818+
outer_product = einsum(a, b, 'b s i d, b s j e -> b i j d e')
819+
outer_product_mean = outer_product / num_msa
817820

818821
# flatten
819822

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.65"
3+
version = "0.1.66"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)