@@ -801,19 +801,21 @@ def forward(
801801
802802 a , b = self .to_hidden (msa ).chunk (2 , dim = - 1 )
803803
804- outer_product = einsum (a , b , 'b s i d, b s j e -> b i j d e s' )
805-
806804 # maybe masked mean for outer product
807805
808806 if exists (msa_mask ):
809- outer_product = einx .multiply ('b i j d e s, b s -> b i j d e s' , outer_product , msa_mask .float ())
807+ a = einx .multiply ('b s i d, b s -> b s i d' , a , msa_mask .float ())
808+ b = einx .multiply ('b s j e, b s -> b s j e' , b , msa_mask .float ())
809+
810+ outer_product = einsum (a , b , 'b s i d, b s j e -> b i j d e' )
810811
811- num = reduce (outer_product , '... s -> ...' , 'sum' )
812- den = reduce (msa_mask .float (), '... s -> ...' , 'sum' )
812+ num_msa = reduce (msa_mask .float (), '... s -> ...' , 'sum' )
813813
814- outer_product_mean = einx .divide ('b i j d e, b' , num , den .clamp (min = self .eps ))
814+ outer_product_mean = einx .divide ('b i j d e, b' , outer_product , num_msa .clamp (min = self .eps ))
815815 else :
816- outer_product_mean = reduce (outer_product , '... s -> ...' , 'mean' )
816+ num_msa = msa .shape [1 ]
817+ outer_product = einsum (a , b , 'b s i d, b s j e -> b i j d e' )
818+ outer_product_mean = outer_product / num_msa
817819
818820 # flatten
819821
0 commit comments