@@ -822,8 +822,7 @@ def __init__(
 
         pos_bias_mlp_dim = dim // 2
         self.pos_bias_mlp = nn.Sequential(
-            Rearrange('... -> ... 1'),
-            nn.Linear(1, pos_bias_mlp_dim),
+            nn.Linear(2, pos_bias_mlp_dim),
             nn.SiLU(),
             nn.Linear(pos_bias_mlp_dim, pos_bias_mlp_dim),
             nn.SiLU(),
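A minimal sketch of what the reworked bias MLP now consumes: a 2-feature relative coordinate per query/key pair instead of a scalar distance, which is why the `Rearrange('... -> ... 1')` is dropped. The sizes below (`dim = 512`, `heads = 8`) and the final projection to `heads` are assumptions for illustration, inferred from the `'... h -> h ...'` rearrange later in this diff.

```python
# Illustrative stand-in for the reworked positional-bias MLP (not the repo's exact module).
# Assumed sizes: dim = 512, heads = 8; the final projection to `heads` is inferred
# from the `'... h -> h ...'` rearrange further down in this diff.
import torch
from torch import nn

dim, heads = 512, 8
pos_bias_mlp_dim = dim // 2

pos_bias_mlp = nn.Sequential(
    nn.Linear(2, pos_bias_mlp_dim),                 # 2 features: (position delta, quantizer delta)
    nn.SiLU(),
    nn.Linear(pos_bias_mlp_dim, pos_bias_mlp_dim),
    nn.SiLU(),
    nn.Linear(pos_bias_mlp_dim, heads)              # one bias value per attention head
)

rel = torch.randn(16, 16, 2)    # pairwise (i, j) relative coordinates
print(pos_bias_mlp(rel).shape)  # torch.Size([16, 16, 8])
```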
@@ -895,18 +894,18 @@ def forward(
         b, n = coarse_token_ids.shape
 
         coarse_length = coarse_token_ids.shape[-1]
-        coarse_offsets = self.codebook_size * torch.arange(self.num_coarse_quantizers, device = device)
+        coarse_offsets = torch.arange(self.num_coarse_quantizers, device = device)
         coarse_seq_length = ceil_div(coarse_token_ids.shape[-1], self.num_coarse_quantizers)
-        coarse_offsets = repeat(coarse_offsets, 'q -> 1 (n q)', n = coarse_seq_length)
-        coarse_offsets = coarse_offsets[:, :coarse_length]
-        coarse_token_ids = coarse_token_ids + coarse_offsets
+        coarse_offsets = repeat(coarse_offsets, 'q -> (n q)', n = coarse_seq_length)
+        coarse_offsets = coarse_offsets[:coarse_length]
+        coarse_token_ids = coarse_token_ids + rearrange(coarse_offsets, '... -> 1 ...') * self.codebook_size
 
         fine_length = fine_token_ids.shape[-1]
-        fine_offsets = self.codebook_size * torch.arange(self.num_fine_quantizers, device = device)
+        fine_offsets = torch.arange(self.num_fine_quantizers, device = device)
         fine_seq_length = ceil_div(fine_token_ids.shape[-1], self.num_fine_quantizers)
-        fine_offsets = repeat(fine_offsets, 'q -> 1 (n q)', n = fine_seq_length)
-        fine_offsets = fine_offsets[:, :fine_length]
-        fine_token_ids = fine_token_ids + fine_offsets
+        fine_offsets = repeat(fine_offsets, 'q -> (n q)', n = fine_seq_length)
+        fine_offsets = fine_offsets[:fine_length]
+        fine_token_ids = fine_token_ids + rearrange(fine_offsets, '... -> 1 ...') * self.codebook_size
 
         coarse_tokens = self.coarse_embedding(coarse_token_ids)
         fine_tokens = self.fine_embedding(fine_token_ids)
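A hedged, standalone sketch of the offsetting above, showing why the raw quantizer indices are now kept separate from the codebook shift: each quantizer's codes land in their own `codebook_size`-wide slice of a shared embedding table, while the unscaled `coarse_offsets` survive for the positional bias later in this diff. The sizes and random ids are made up, and `math.ceil` stands in for `ceil_div`.

```python
# Standalone sketch of the coarse-token offsetting (illustrative sizes; math.ceil in place of ceil_div)
import torch
from math import ceil
from einops import rearrange, repeat

codebook_size = 1024
num_coarse_quantizers = 3
coarse_token_ids = torch.randint(0, codebook_size, (2, 7))   # (batch, n); n need not divide evenly

coarse_length = coarse_token_ids.shape[-1]
coarse_offsets = torch.arange(num_coarse_quantizers)                          # 0, 1, 2
coarse_seq_length = ceil(coarse_length / num_coarse_quantizers)
coarse_offsets = repeat(coarse_offsets, 'q -> (n q)', n = coarse_seq_length)  # 0,1,2,0,1,2,...
coarse_offsets = coarse_offsets[:coarse_length]                               # trim to actual length

# shift each quantizer's codes into its own slice of one shared embedding table
offset_ids = coarse_token_ids + rearrange(coarse_offsets, '... -> 1 ...') * codebook_size
assert offset_ids.max() < codebook_size * num_coarse_quantizers
```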
@@ -944,13 +943,18 @@ def forward(
 
         seq_positions = torch.cat((coarse_pos, fine_pos), dim = -1)
 
-        rel_dist = (rearrange(seq_positions, 'i -> i 1') - rearrange(seq_positions, 'j -> 1 j'))
-        rel_dist = rel_dist + max_seq_len # offset so all positive indices
+        coarse_offsets = F.pad(coarse_offsets, (1, 0), value = 0)
+        fine_offsets = fine_offsets + self.num_coarse_quantizers
+        fine_offsets = F.pad(fine_offsets, (1, 0), value = 0)
 
-        mlp_inp = torch.arange(-max_seq_len, max_seq_len + 1, device = device).float()
-        attn_bias = self.pos_bias_mlp(mlp_inp)
+        seq_offsets = torch.cat((coarse_offsets, fine_offsets), dim = -1)
 
-        attn_bias = rearrange(attn_bias[rel_dist], '... h -> h ...')
+        pos_mlp_input = torch.stack((seq_positions, seq_offsets), dim = -1)
+        rel_dist = (rearrange(pos_mlp_input, 'i c -> i 1 c') - rearrange(pos_mlp_input, 'j c -> 1 j c'))
+
+        attn_bias = self.pos_bias_mlp(rel_dist.float())
+
+        attn_bias = rearrange(attn_bias, '... h -> h ...')
 
         # need to make sure start token has a custom positional bias
 
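For reviewers, a self-contained sketch of the new bias path end to end, under assumed toy sizes: each token carries a (sequence position, quantizer offset) pair, the pairwise differences of those pairs feed the MLP, and the output becomes one additive bias per head. The start-token handling referenced in the comment above is left out.

```python
# Toy end-to-end sketch of the 2-D relative positional bias (assumed sizes; start-token handling omitted)
import torch
from torch import nn
from einops import rearrange

heads = 8
pos_bias_mlp = nn.Sequential(nn.Linear(2, 32), nn.SiLU(), nn.Linear(32, heads))

seq_positions = torch.tensor([0, 0, 0, 1, 1, 1])   # time step of each token
seq_offsets   = torch.tensor([0, 1, 2, 0, 1, 2])   # which quantizer each token came from

pos_mlp_input = torch.stack((seq_positions, seq_offsets), dim = -1)                             # (n, 2)
rel_dist = rearrange(pos_mlp_input, 'i c -> i 1 c') - rearrange(pos_mlp_input, 'j c -> 1 j c')  # (n, n, 2)

attn_bias = pos_bias_mlp(rel_dist.float())          # (n, n, heads)
attn_bias = rearrange(attn_bias, '... h -> h ...')  # (heads, n, n), added to the attention logits
print(attn_bias.shape)                              # torch.Size([8, 6, 6])
```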