@@ -2804,7 +2804,7 @@ def forward(
28042804 additional_residue_feats : Float ['b n 10' ],
28052805 residue_atom_lens : Int ['b n' ] | None = None ,
28062806 atom_mask : Bool ['b m' ] | None = None ,
2807- token_bond : Float ['b n n' ] | None = None ,
2807+ token_bond : Bool ['b n n' ] | None = None ,
28082808 msa : Float ['b s n d' ] | None = None ,
28092809 msa_mask : Bool ['b s' ] | None = None ,
28102810 templates : Float ['b t n n dt' ] | None = None ,
@@ -2882,14 +2882,14 @@ def forward(
28822882 # (1) mask out diagonal - token to itself does not count as a bond
28832883 # (2) symmetrize, in case it is not already symmetrical (could also throw an error)
28842884
2885- assert torch . allclose ( token_bond , rearrange (token_bond , 'b i j -> b j i' )), 'token bond must be symmetrical'
2885+ token_bond = token_bond | rearrange (token_bond , 'b i j -> b j i' )
28862886 diagonal = torch .eye (seq_len , device = self .device , dtype = torch .bool )
2887- token_bond .masked_fill_ (diagonal , 0. )
2887+ token_bond .masked_fill_ (diagonal , False )
28882888 else :
28892889 seq_arange = torch .arange (seq_len , device = self .device )
2890- token_bond = ( einx .subtract ('i, j -> i j' , seq_arange , seq_arange ).abs () == 1 ). float ()
2890+ token_bond = einx .subtract ('i, j -> i j' , seq_arange , seq_arange ).abs () == 1
28912891
2892- token_bond_feats = self .token_bond_to_pairwise_feat (token_bond )
2892+ token_bond_feats = self .token_bond_to_pairwise_feat (token_bond . float () )
28932893
28942894 pairwise_init = pairwise_init + token_bond_feats
28952895
0 commit comments