Skip to content

Commit 1d0cb78

Browse files
author
sangchengmeng
committed
fix-get-audio-length
1 parent 4dcb592 commit 1d0cb78

File tree

1 file changed

+33
-9
lines changed

1 file changed

+33
-9
lines changed

lightllm/models/internvl/model.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
)
2020
from lightllm.models.internvl.layer_weights.pre_and_post_layer_weight import InternVLInternlm2PreAndPostLayerWeight
2121
from lightllm.models.vit import get_image_patch_func
22+
from lightllm.models.whisper.defaults import MIN_AUDIO_LEN
2223

2324
IMG_START_TOKEN = "<img>"
2425
IMG_END_TOKEN = "</img>"
@@ -47,6 +48,9 @@ def __init__(self, tokenizer, model_cfg, **kwargs):
4748
self.audio_end_id = tokenizer.convert_tokens_to_ids(self.audio_end_tag)
4849
self.get_image_patch_func = get_image_patch_func(kwargs["weight_dir"])
4950

51+
self.audio_min_length = MIN_AUDIO_LEN
52+
self.audio_max_length = 16000 * 30
53+
5054
def init_imageitem_extral_params(
5155
self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams
5256
):
@@ -81,15 +85,35 @@ def get_image_token_length(self, img: ImageItem):
8185

8286
def get_audio_token_length(self, audio: AudioItem):
8387
L = audio.audio_length
84-
mel_len = L // 160
85-
dilation = 1
86-
L_in = mel_len
87-
for (padding, kernel_size, stride) in eval("[(1,3,1)] + [(1,3,2)] "):
88-
L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
89-
L_out = 1 + L_out // stride
90-
L_in = L_out
91-
audio_len_after_cnn = L_out
92-
audio_token_num = (audio_len_after_cnn - 2) // 2 + 1
88+
audio_token_num = 0
89+
chunk_lens = []
90+
if L <= self.audio_max_length:
91+
cur_len = L
92+
if cur_len < self.audio_min_length:
93+
cur_len = MIN_AUDIO_LEN
94+
chunk_lens.append(cur_len)
95+
else:
96+
start = 0
97+
while start < L:
98+
end = min(start + self.audio_max_length, L)
99+
cur_len = end - start
100+
101+
if cur_len < MIN_AUDIO_LEN:
102+
cur_len = MIN_AUDIO_LEN
103+
104+
chunk_lens.append(cur_len)
105+
start = end
106+
for chunk_len in chunk_lens:
107+
mel_len = chunk_len // 160
108+
dilation = 1
109+
L_in = mel_len
110+
for (padding, kernel_size, stride) in eval("[(1,3,1)] + [(1,3,2)] "):
111+
L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
112+
L_out = 1 + L_out // stride
113+
L_in = L_out
114+
audio_len_after_cnn = L_out
115+
chunk_token_num = (audio_len_after_cnn - 2) // 2 + 1
116+
audio_token_num += int(chunk_token_num)
93117
return audio_token_num
94118

95119
# only change the impl of the encode func:

0 commit comments

Comments
 (0)