
Commit b29141f

use 2d dynamic positional bias for fine transformer, to try to improve training at a greater number of fine quantizers
1 parent 9ba1b04 commit b29141f

2 files changed: 20 additions & 16 deletions

audiolm_pytorch/audiolm_pytorch.py

Lines changed: 19 additions & 15 deletions
@@ -822,8 +822,7 @@ def __init__(
 
         pos_bias_mlp_dim = dim // 2
         self.pos_bias_mlp = nn.Sequential(
-            Rearrange('... -> ... 1'),
-            nn.Linear(1, pos_bias_mlp_dim),
+            nn.Linear(2, pos_bias_mlp_dim),
             nn.SiLU(),
             nn.Linear(pos_bias_mlp_dim, pos_bias_mlp_dim),
             nn.SiLU(),
@@ -895,18 +894,18 @@ def forward(
         b, n = coarse_token_ids.shape
 
         coarse_length = coarse_token_ids.shape[-1]
-        coarse_offsets = self.codebook_size * torch.arange(self.num_coarse_quantizers, device = device)
+        coarse_offsets = torch.arange(self.num_coarse_quantizers, device = device)
         coarse_seq_length = ceil_div(coarse_token_ids.shape[-1], self.num_coarse_quantizers)
-        coarse_offsets = repeat(coarse_offsets, 'q -> 1 (n q)', n = coarse_seq_length)
-        coarse_offsets = coarse_offsets[:, :coarse_length]
-        coarse_token_ids = coarse_token_ids + coarse_offsets
+        coarse_offsets = repeat(coarse_offsets, 'q -> (n q)', n = coarse_seq_length)
+        coarse_offsets = coarse_offsets[:coarse_length]
+        coarse_token_ids = coarse_token_ids + rearrange(coarse_offsets, '... -> 1 ...') * self.codebook_size
 
         fine_length = fine_token_ids.shape[-1]
-        fine_offsets = self.codebook_size * torch.arange(self.num_fine_quantizers, device = device)
+        fine_offsets = torch.arange(self.num_fine_quantizers, device = device)
         fine_seq_length = ceil_div(fine_token_ids.shape[-1], self.num_fine_quantizers)
-        fine_offsets = repeat(fine_offsets, 'q -> 1 (n q)', n = fine_seq_length)
-        fine_offsets = fine_offsets[:, :fine_length]
-        fine_token_ids = fine_token_ids + fine_offsets
+        fine_offsets = repeat(fine_offsets, 'q -> (n q)', n = fine_seq_length)
+        fine_offsets = fine_offsets[:fine_length]
+        fine_token_ids = fine_token_ids + rearrange(fine_offsets, '... -> 1 ...') * self.codebook_size
 
         coarse_tokens = self.coarse_embedding(coarse_token_ids)
         fine_tokens = self.fine_embedding(fine_token_ids)
@@ -944,13 +943,18 @@ def forward(
 
         seq_positions = torch.cat((coarse_pos, fine_pos), dim = -1)
 
-        rel_dist = (rearrange(seq_positions, 'i -> i 1') - rearrange(seq_positions, 'j -> 1 j'))
-        rel_dist = rel_dist + max_seq_len # offset so all positive indices
+        coarse_offsets = F.pad(coarse_offsets, (1, 0), value = 0)
+        fine_offsets = fine_offsets + self.num_coarse_quantizers
+        fine_offsets = F.pad(fine_offsets, (1, 0), value = 0)
 
-        mlp_inp = torch.arange(-max_seq_len, max_seq_len + 1, device = device).float()
-        attn_bias = self.pos_bias_mlp(mlp_inp)
+        seq_offsets = torch.cat((coarse_offsets, fine_offsets), dim = -1)
 
-        attn_bias = rearrange(attn_bias[rel_dist], '... h -> h ...')
+        pos_mlp_input = torch.stack((seq_positions, seq_offsets), dim = -1)
+        rel_dist = (rearrange(pos_mlp_input, 'i c -> i 1 c') - rearrange(pos_mlp_input, 'j c -> 1 j c'))
+
+        attn_bias = self.pos_bias_mlp(rel_dist.float())
+
+        attn_bias = rearrange(attn_bias, '... h -> h ...')
 
         # need to make sure start token has a custom positional bias
 
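For illustration, here is a minimal, self-contained sketch of what the change amounts to: instead of feeding a scalar relative distance into the bias MLP, each query/key pair is described by a 2-vector of (relative time step, relative quantizer index), and the MLP maps that vector to one bias per attention head. The class and variable names below (DynamicPositionBias2D, positions, offsets) and the toy sizes are assumptions for the sketch, not the library's API.

import torch
from torch import nn

# Sketch only (assumed names and sizes): a continuous 2D dynamic positional bias.
# Each query/key pair is described by its relative (time step, quantizer index),
# and a small MLP maps that 2-vector to one bias value per attention head.

class DynamicPositionBias2D(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        hidden = dim // 2
        self.mlp = nn.Sequential(
            nn.Linear(2, hidden),      # 2 input coordinates, mirroring nn.Linear(2, ...) in the diff
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, heads)   # one bias per head
        )

    def forward(self, positions, offsets):
        # positions: (n,) time step of each token
        # offsets:   (n,) quantizer index of each token
        coords = torch.stack((positions, offsets), dim = -1).float()  # (n, 2)
        rel = coords[:, None, :] - coords[None, :, :]                 # (n, n, 2) pairwise differences
        bias = self.mlp(rel)                                          # (n, n, heads)
        return bias.permute(2, 0, 1)                                  # (heads, n, n), added to attention logits

# toy usage: 4 time steps, 3 quantizers interleaved per step
bias_fn = DynamicPositionBias2D(dim = 512, heads = 8)
positions = torch.arange(4).repeat_interleave(3)   # 0 0 0 1 1 1 2 2 2 3 3 3
offsets   = torch.arange(3).repeat(4)              # 0 1 2 0 1 2 ...
attn_bias = bias_fn(positions, offsets)
print(attn_bias.shape)                             # torch.Size([8, 12, 12])

The second hunk of the diff builds exactly these two coordinate vectors: seq_positions carries the time step and seq_offsets the per-token quantizer index, with the fine quantizers shifted by num_coarse_quantizers so coarse and fine tokens occupy distinct ranges along the second axis.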
audiolm_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '0.22.3'
+__version__ = '0.23.2'
