@@ -795,25 +795,28 @@ def forward(
795795 msa_mask : Bool ['b s' ] | None = None
796796 ) -> Float ['b n n dp' ]:
797797
798+ num_msa = msa .shape [1 ]
799+
798800 msa = self .norm (msa )
799801
800802 # line 2
801803
802804 a , b = self .to_hidden (msa ).chunk (2 , dim = - 1 )
803805
804- outer_product = einsum (a , b , 'b s i d, b s j e -> b i j d e s' )
805-
806806 # maybe masked mean for outer product
807807
808808 if exists (msa_mask ):
809- outer_product = einx .multiply ('b i j d e s, b s -> b i j d e s' , outer_product , msa_mask .float ())
809+ a = einx .multiply ('b s i d, b s -> b s i d' , a , msa_mask .float ())
810+ b = einx .multiply ('b s j e, b s -> b s j e' , b , msa_mask .float ())
811+
812+ outer_product = einsum (a , b , 'b s i d, b s j e -> b i j d e' )
810813
811- num = reduce (outer_product , '... s -> ...' , 'sum' )
812- den = reduce (msa_mask .float (), '... s -> ...' , 'sum' )
814+ num_msa = reduce (msa_mask .float (), '... s -> ...' , 'sum' )
813815
814- outer_product_mean = einx .divide ('b i j d e, b' , num , den .clamp (min = self .eps ))
816+ outer_product_mean = einx .divide ('b i j d e, b' , outer_product , num_msa .clamp (min = self .eps ))
815817 else :
816- outer_product_mean = reduce (outer_product , '... s -> ...' , 'mean' )
818+ outer_product = einsum (a , b , 'b s i d, b s j e -> b i j d e' )
819+ outer_product_mean = outer_product / num_msa
817820
818821 # flatten
819822
0 commit comments