Skip to content

Commit 854d606

Browse files
committed
address a memory issue with OuterProductMean raised by Sergey, taking care of both masked and unmasked versions
1 parent 64fc0b1 commit 854d606

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -801,19 +801,21 @@ def forward(
801801

802802
a, b = self.to_hidden(msa).chunk(2, dim = -1)
803803

804-
outer_product = einsum(a, b, 'b s i d, b s j e -> b i j d e s')
805-
806804
# maybe masked mean for outer product
807805

808806
if exists(msa_mask):
809-
outer_product = einx.multiply('b i j d e s, b s -> b i j d e s', outer_product, msa_mask.float())
807+
a = einx.multiply('b s i d, b s -> b s i d', a, msa_mask.float())
808+
b = einx.multiply('b s j e, b s -> b s j e', b, msa_mask.float())
809+
810+
outer_product = einsum(a, b, 'b s i d, b s j e -> b i j d e')
810811

811-
num = reduce(outer_product, '... s -> ...', 'sum')
812-
den = reduce(msa_mask.float(), '... s -> ...', 'sum')
812+
num_msa = reduce(msa_mask.float(), '... s -> ...', 'sum')
813813

814-
outer_product_mean = einx.divide('b i j d e, b', num, den.clamp(min = self.eps))
814+
outer_product_mean = einx.divide('b i j d e, b', outer_product, num_msa.clamp(min = self.eps))
815815
else:
816-
outer_product_mean = reduce(outer_product, '... s -> ...', 'mean')
816+
num_msa = msa.shape[1]
817+
outer_product = einsum(a, b, 'b s i d, b s j e -> b i j d e')
818+
outer_product_mean = outer_product / num_msa
817819

818820
# flatten
819821

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.65"
3+
version = "0.1.66"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)