
Commit 3c63957

fix the memory issue with the 2d relative attention bias (seq and quantizer positions) in the fine transformer
1 parent ff60868 commit 3c63957

File tree

audiolm_pytorch/audiolm_pytorch.py
audiolm_pytorch/version.py

2 files changed: +37, -4 lines

audiolm_pytorch/audiolm_pytorch.py

Lines changed: 36 additions & 3 deletions
@@ -938,7 +938,7 @@ def forward(
         coarse_pos = repeat(coarse_pos, 'n -> (n q)', q = self.num_coarse_quantizers)[:coarse_length]
         fine_pos = repeat(fine_pos, 'n -> (n q)', q = self.num_fine_quantizers)[:fine_length]
 
-        coarse_pos = F.pad(coarse_pos, (1, 0), value = -1) # -1 for start token
+        coarse_pos = F.pad(coarse_pos, (1, 0), value = -1)
         fine_pos = F.pad(fine_pos, (1, 0), value = -1)
 
         seq_positions = torch.cat((coarse_pos, fine_pos), dim = -1)
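For context on the hunk above, here is a minimal, standalone sketch of what the position bookkeeping produces, using made-up sizes (3 timesteps, 2 coarse quantizers); only the repeat, F.pad and the clamp(min = 0) introduced in the next hunk are taken from the diff, everything else is illustrative:

import torch
import torch.nn.functional as F
from einops import repeat

num_coarse_quantizers = 2      # assumed size for the example
coarse_pos = torch.arange(3)   # per-timestep positions: [0, 1, 2]

# each timestep position is repeated once per quantizer -> [0, 0, 1, 1, 2, 2]
coarse_pos = repeat(coarse_pos, 'n -> (n q)', q = num_coarse_quantizers)

# prepend -1 as the position of the start token, as in the padded line above
coarse_pos = F.pad(coarse_pos, (1, 0), value = -1)
print(coarse_pos)                 # tensor([-1,  0,  0,  1,  1,  2,  2])

# the clamp(min = 0) from the next hunk folds the start token's -1 onto
# position 0 before relative distances are computed
print(coarse_pos.clamp(min = 0))  # tensor([0, 0, 0, 1, 1, 2, 2])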
@@ -949,10 +949,43 @@ def forward(
 
         seq_offsets = torch.cat((coarse_offsets, fine_offsets), dim = -1)
 
-        pos_mlp_input = torch.stack((seq_positions, seq_offsets), dim = -1)
+        pos_mlp_input = torch.stack((seq_positions.clamp(min = 0), seq_offsets), dim = -1)
+
+        num_offsets = self.num_fine_quantizers + self.num_coarse_quantizers
+
+        # relative positions are always (2 * N - 1), where N is the length of the dimension
+
+        rel_seq_len, rel_offsets = map(lambda n: 2 * n - 1, (max_seq_len, num_offsets))
+
+        # get all relative distances
+
         rel_dist = (rearrange(pos_mlp_input, 'i c -> i 1 c') - rearrange(pos_mlp_input, 'j c -> 1 j c'))
 
-        attn_bias = self.pos_bias_mlp(rel_dist.float())
+        # get all possible relative distances for the attention bias to be computed from the mlp
+        # which would be - (2 * N - 1) * (2 * Q - 1) - where N = sequence length and Q = total quantizers
+
+        rel_seq_len_range = repeat(torch.arange(rel_seq_len, device = device), 'n -> (n q)', q = rel_offsets)
+        rel_offset_range = repeat(torch.arange(rel_offsets, device = device), 'q -> (n q)', n = rel_seq_len)
+
+        mlp_inputs = torch.stack((rel_seq_len_range, rel_offset_range), dim = -1)
+
+        # implicitly parameterized relative distances, by sequence and quantizer positions
+
+        attn_bias = self.pos_bias_mlp(mlp_inputs.float())
+
+        # translate coordinates of (rel_seq_pos, rel_quantizer_offset) -> positive index to select from attn bias
+
+        rel_dist_seq_pos, rel_dist_seq_offset = rel_dist.unbind(dim = -1)
+
+        rel_dist_seq_pos += max_seq_len - 1
+        rel_dist_seq_offset += num_offsets - 1
+
+        rel_dist_indices = rel_dist_seq_pos * rel_offsets + rel_dist_seq_offset
+
+        # select the relative positional attention bias outputted by the MLP
+        # savings go from (N * Q) ^ 2 -> ~ (4 * N * Q)
+
+        attn_bias = attn_bias[rel_dist_indices]
 
         attn_bias = rearrange(attn_bias, '... h -> h ...')
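To make the memory fix concrete, here is a condensed, self-contained sketch of the table-and-gather trick the second hunk implements, under assumed toy sizes (8 heads, 4 timesteps, 3 total quantizers) and with a stand-in two-layer MLP in place of the model's pos_bias_mlp; the start-token handling is omitted for brevity:

import torch
from torch import nn
from einops import rearrange, repeat

heads       = 8        # assumed
max_seq_len = 4        # assumed number of timesteps
num_offsets = 3        # assumed total quantizers (coarse + fine)
device      = 'cpu'

# stand-in for the bias MLP: (rel seq distance, rel quantizer offset) -> per-head bias
pos_bias_mlp = nn.Sequential(nn.Linear(2, 32), nn.SiLU(), nn.Linear(32, heads))

# per-token (sequence position, quantizer offset) coordinates, flattened to length N * Q
seq_positions = repeat(torch.arange(max_seq_len, device = device), 'n -> (n q)', q = num_offsets)
seq_offsets   = repeat(torch.arange(num_offsets, device = device), 'q -> (n q)', n = max_seq_len)
pos_mlp_input = torch.stack((seq_positions, seq_offsets), dim = -1)

# pairwise relative coordinates, shape (N * Q, N * Q, 2) - cheap, integers only, no MLP yet
rel_dist = rearrange(pos_mlp_input, 'i c -> i 1 c') - rearrange(pos_mlp_input, 'j c -> 1 j c')

# the MLP only ever needs to see (2N - 1) * (2Q - 1) distinct inputs
rel_seq_len, rel_offsets = 2 * max_seq_len - 1, 2 * num_offsets - 1
rel_seq_range    = repeat(torch.arange(rel_seq_len, device = device), 'n -> (n q)', q = rel_offsets)
rel_offset_range = repeat(torch.arange(rel_offsets, device = device), 'q -> (n q)', n = rel_seq_len)
bias_table = pos_bias_mlp(torch.stack((rel_seq_range, rel_offset_range), dim = -1).float())

# shift relative coordinates into non-negative ranges and flatten them into table indices
rel_seq, rel_off = rel_dist.unbind(dim = -1)
indices = (rel_seq + max_seq_len - 1) * rel_offsets + (rel_off + num_offsets - 1)

attn_bias = bias_table[indices]                     # (N * Q, N * Q, heads)
attn_bias = rearrange(attn_bias, '... h -> h ...')  # (heads, N * Q, N * Q)
print(attn_bias.shape)                              # torch.Size([8, 12, 12])

With these toy sizes the MLP runs on 35 unique (relative position, relative offset) pairs instead of 144 pairwise distances; at realistic sequence lengths and quantizer counts, (2N - 1) * (2Q - 1) is roughly 4 * N * Q, which is the (N * Q) ^ 2 -> ~ (4 * N * Q) saving the diff comment refers to.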

audiolm_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '0.23.3'
+__version__ = '0.23.5'
