@@ -841,81 +841,79 @@ class RelativePositionEncoding(Module):
841841
842842 def __init__ (
843843 self ,
844+ * ,
844845 r_max = 32 ,
845846 s_max = 2 ,
846- out_dim = 128
847+ dim_out = 128
847848 ):
848849 super ().__init__ ()
849850 self .r_max = r_max
850851 self .s_max = s_max
851852
852- input_dim = (2 * r_max + 2 ) + (2 * r_max + 2 ) + 1 + (2 * s_max + 2 )
853- self .out_embedder = LinearNoBias (input_dim , out_dim )
854-
853+ dim_input = (2 * r_max + 2 ) + (2 * r_max + 2 ) + 1 + (2 * s_max + 2 )
854+ self .out_embedder = LinearNoBias (dim_input , dim_out )
855+
855856 @typecheck
856857 def forward (
857858 self ,
858859 * ,
859860 additional_residue_feats : Float ['b n rf' ]
860861 ) -> Float ['b n n dp' ]:
862+
863+ device = additional_residue_feats .device
864+ assert additional_residue_feats .shape [- 1 ] >= 5
865+
866+ res_idx , token_idx , asym_id , entity_id , sym_id = additional_residue_feats [..., :5 ].unbind (dim = - 1 )
861867
862- res_idx = additional_residue_feats [..., 0 ]
863- token_idx = additional_residue_feats [..., 1 ]
864- asym_id = additional_residue_feats [..., 2 ]
865- entity_id = additional_residue_feats [..., 3 ]
866- sym_id = additional_residue_feats [..., 4 ]
867-
868- diff_res_idx = rearrange (res_idx , 'b n -> b n 1' ) \
869- - rearrange (res_idx , 'b n -> b 1 n' )
870- diff_token_idx = rearrange (token_idx , 'b n -> b n 1' ) \
871- - rearrange (token_idx , 'b n -> b 1 n' )
872- diff_sym_id = rearrange (sym_id , 'b n -> b n 1' ) \
873- - rearrange (sym_id , 'b n -> b 1 n' )
874- mask_same_chain = rearrange (asym_id , 'b n -> b n 1' ) \
875- - rearrange (asym_id , 'b n -> b 1 n' ) == 0
868+ diff_res_idx = einx .subtract ('b i, b j -> b i j' , res_idx , res_idx )
869+ diff_token_idx = einx .subtract ('b i, b j -> b i j' , token_idx , token_idx )
870+ diff_sym_id = einx .subtract ('b i, b j -> b i j' , sym_id , sym_id )
871+
872+ mask_same_chain = einx .subtract ('b i, b j -> b i j' , asym_id , asym_id ) == 0
876873 mask_same_res = diff_res_idx == 0
877- mask_same_entity = (rearrange (entity_id , 'b n -> b n 1' ) \
878- - rearrange (entity_id , 'b n -> b 1 n' ) == 0 ).unsqueeze (- 1 )
874+ mask_same_entity = einx .subtract ('b i, b j -> b i j 1' , entity_id , entity_id ) == 0
879875
880876 d_res = torch .where (
881877 mask_same_chain ,
882878 torch .clip (diff_res_idx + self .r_max , 0 , 2 * self .r_max ),
883879 2 * self .r_max + 1
884880 )
881+
885882 d_token = torch .where (
886883 mask_same_chain * mask_same_res ,
887884 torch .clip (diff_token_idx + self .r_max , 0 , 2 * self .r_max ),
888885 2 * self .r_max + 1
889886 )
887+
890888 d_chain = torch .where (
891889 ~ mask_same_chain ,
892890 torch .clip (diff_sym_id + self .s_max , 0 , 2 * self .s_max ),
893891 2 * self .s_max + 1
894892 )
895893
896894 def onehot (x , bins ):
897- _ , indexes = (x .view (- 1 , 1 ) - bins .view (1 , - 1 )).abs ().min (dim = 1 )
898- indexes = indexes .type (torch .int64 ).view (- 1 , 1 )
895+ x , packed_shape = pack_one (x , '*' )
896+ dist_from_bins = einx .subtract ('i, j -> i j' , x , bins )
897+ indexes = dist_from_bins .abs ().min (dim = 1 , keepdim = True ).indices
898+ indexes = rearrange (indexes .long (), 'i j -> (i j) 1' )
899899 one_hots = torch .zeros (indexes .shape [0 ], len (bins )).scatter_ (1 , indexes , 1 )
900- out = rearrange (one_hots , '(b n k) d -> b n k d' , n = x .shape [1 ], k = x .shape [2 ])
901- return out
902-
903- a_rel_pos = onehot (d_res , torch .arange (2 * self .r_max + 2 ))
904- a_rel_token = onehot (d_token , torch .arange (2 * self .r_max + 2 ))
905- a_rel_chain = onehot (d_chain , torch .arange (2 * self .s_max + 2 ))
906-
907- p = self .out_embedder (
908- torch .cat ([
909- a_rel_pos ,
910- a_rel_token ,
911- mask_same_entity ,
912- a_rel_chain
913- ], dim = - 1 )
914- )
915-
916- return p
917-
918-
900+ return unpack_one (one_hots , packed_shape , '* d' )
901+
902+ r_arange = torch .arange (2 * self .r_max + 2 , device = device )
903+ s_arange = torch .arange (2 * self .s_max + 2 , device = device )
904+
905+ a_rel_pos = onehot (d_res , r_arange )
906+ a_rel_token = onehot (d_token , r_arange )
907+ a_rel_chain = onehot (d_chain , s_arange )
908+
909+ out , _ = pack ((
910+ a_rel_pos ,
911+ a_rel_token ,
912+ mask_same_entity ,
913+ a_rel_chain
914+ ), 'b i j *' )
915+
916+ return self .out_embedder (out )
919917
920918class TemplateEmbedder (Module ):
921919 """ Algorithm 16 """
@@ -2021,6 +2019,10 @@ def __init__(
20212019 pair_bias_attn_heads = 16 ,
20222020 dropout_row_prob = 0.25 ,
20232021 pairwise_block_kwargs = dict ()
2022+ ),
2023+ relative_position_encoding_kwargs : dict = dict (
2024+ r_max = 32 ,
2025+ s_max = 2 ,
20242026 )
20252027 ):
20262028 super ().__init__ ()
@@ -2043,6 +2045,15 @@ def __init__(
20432045
20442046 dim_single_inputs = dim_input_embedder_token + dim_additional_residue_feats
20452047
2048+ # relative positional encoding
2049+ # used by pairwise in main alphafold2 trunk
2050+ # and also in the diffusion module separately from alphafold3
2051+
2052+ self .relative_position_encoding = RelativePositionEncoding (
2053+ dim_out = dim_pairwise ,
2054+ ** relative_position_encoding_kwargs
2055+ )
2056+
20462057 # templates
20472058
20482059 self .template_embedder = TemplateEmbedder (
@@ -2130,6 +2141,8 @@ def forward(
21302141 resolved_labels : Int ['b n' ] | None = None ,
21312142 ) -> Float ['b m 3' ] | Float ['' ]:
21322143
2144+ w = self .atoms_per_window
2145+
21332146 # embed inputs
21342147
21352148 (
@@ -2145,7 +2158,15 @@ def forward(
21452158 additional_residue_feats = additional_residue_feats
21462159 )
21472160
2148- w = self .atoms_per_window
2161+ # relative positional encoding
2162+
2163+ relative_position_encoding = self .relative_position_encoding (
2164+ additional_residue_feats = additional_residue_feats
2165+ )
2166+
2167+ pairwise_init = pairwise_init + relative_position_encoding
2168+
2169+ # pairwise mask
21492170
21502171 mask = reduce (atom_mask , 'b (n w) -> b n' , w = w , reduction = 'any' )
21512172 pairwise_mask = einx .logical_and ('b i, b j -> b i j' , mask , mask )
0 commit comments