@@ -938,7 +938,7 @@ def forward(
938938 coarse_pos = repeat (coarse_pos , 'n -> (n q)' , q = self .num_coarse_quantizers )[:coarse_length ]
939939 fine_pos = repeat (fine_pos , 'n -> (n q)' , q = self .num_fine_quantizers )[:fine_length ]
940940
941- coarse_pos = F .pad (coarse_pos , (1 , 0 ), value = - 1 ) # -1 for start token
941+ coarse_pos = F .pad (coarse_pos , (1 , 0 ), value = - 1 )
942942 fine_pos = F .pad (fine_pos , (1 , 0 ), value = - 1 )
943943
944944 seq_positions = torch .cat ((coarse_pos , fine_pos ), dim = - 1 )
@@ -949,10 +949,43 @@ def forward(
949949
950950 seq_offsets = torch .cat ((coarse_offsets , fine_offsets ), dim = - 1 )
951951
952- pos_mlp_input = torch .stack ((seq_positions , seq_offsets ), dim = - 1 )
952+ pos_mlp_input = torch .stack ((seq_positions .clamp (min = 0 ), seq_offsets ), dim = - 1 )
953+
954+ num_offsets = self .num_fine_quantizers + self .num_coarse_quantizers
955+
956+ # relative positions are always (2 * N - 1), where N is the length of the dimension
957+
958+ rel_seq_len , rel_offsets = map (lambda n : 2 * n - 1 , (max_seq_len , num_offsets ))
959+
960+ # get all relative distances
961+
953962 rel_dist = (rearrange (pos_mlp_input , 'i c -> i 1 c' ) - rearrange (pos_mlp_input , 'j c -> 1 j c' ))
954963
955- attn_bias = self .pos_bias_mlp (rel_dist .float ())
964+ # get all possible relative distances for the attention bias to be computed from the mlp
965+ # which would be (2 * N - 1) * (2 * Q - 1) entries, where N = sequence length and Q = total quantizers
966+
967+ rel_seq_len_range = repeat (torch .arange (rel_seq_len , device = device ), 'n -> (n q)' , q = rel_offsets )
968+ rel_offset_range = repeat (torch .arange (rel_offsets , device = device ), 'q -> (n q)' , n = rel_seq_len )
969+
970+ mlp_inputs = torch .stack ((rel_seq_len_range , rel_offset_range ), dim = - 1 )
971+
972+ # implicitly parameterized relative distances, by sequence and quantizer positions
973+
974+ attn_bias = self .pos_bias_mlp (mlp_inputs .float ())
975+
976+ # translate coordinates of (rel_seq_pos, rel_quantizer_offset) -> positive index to select from attn bias
977+
978+ rel_dist_seq_pos , rel_dist_seq_offset = rel_dist .unbind (dim = - 1 )
979+
980+ rel_dist_seq_pos += max_seq_len - 1
981+ rel_dist_seq_offset += num_offsets - 1
982+
983+ rel_dist_indices = rel_dist_seq_pos * rel_offsets + rel_dist_seq_offset
984+
985+ # select the relative positional attention bias produced by the MLP
986+ # savings go from (N * Q) ^ 2 -> ~ (4 * N * Q)
987+
988+ attn_bias = attn_bias [rel_dist_indices ]
956989
957990 attn_bias = rearrange (attn_bias , '... h -> h ...' )
958991
0 commit comments