@@ -1606,7 +1606,22 @@ def forward(
16061606
16071607 if is_unpacked_repr :
16081608 pairwise_repr_cond = repeat (pairwise_repr_cond , 'b i j dp -> b (i w1) (j w2) dp' , w1 = w , w2 = w )
1609- atompair_feats = pairwise_repr_cond + atompair_feats
1609+ else :
1610+ # todo - fix by doing a specialized fn for this
1611+
1612+ repeated_residue_atom_lens = repeat (residue_atom_lens , 'b ... -> (b r) ...' , r = pairwise_repr_cond .shape [1 ])
1613+ pairwise_repr_cond , ps = pack_one (pairwise_repr_cond , '* n dp' )
1614+ pairwise_repr_cond = repeat_consecutive_with_lens (pairwise_repr_cond , repeated_residue_atom_lens )
1615+ pairwise_repr_cond = unpack_one (pairwise_repr_cond , ps , '* n dp' )
1616+
1617+ pairwise_repr_cond = rearrange (pairwise_repr_cond , 'b i j dp -> b j i dp' )
1618+ repeated_residue_atom_lens = repeat (residue_atom_lens , 'b ... -> (b r) ...' , r = pairwise_repr_cond .shape [1 ])
1619+ pairwise_repr_cond , ps = pack_one (pairwise_repr_cond , '* n dp' )
1620+ pairwise_repr_cond = repeat_consecutive_with_lens (pairwise_repr_cond , repeated_residue_atom_lens )
1621+ pairwise_repr_cond = unpack_one (pairwise_repr_cond , ps , '* n dp' )
1622+ pairwise_repr_cond = rearrange (pairwise_repr_cond , 'b j i dp -> b i j dp' )
1623+
1624+ atompair_feats = pairwise_repr_cond + atompair_feats
16101625
16111626 # condition atompair feats further with single atom repr
16121627
0 commit comments