|
19 | 19 | ) |
20 | 20 | from lightllm.models.internvl.layer_weights.pre_and_post_layer_weight import InternVLInternlm2PreAndPostLayerWeight |
21 | 21 | from lightllm.models.vit import get_image_patch_func |
| 22 | +from lightllm.models.whisper.defaults import MIN_AUDIO_LEN |
22 | 23 |
|
23 | 24 | IMG_START_TOKEN = "<img>" |
24 | 25 | IMG_END_TOKEN = "</img>" |
@@ -47,6 +48,9 @@ def __init__(self, tokenizer, model_cfg, **kwargs): |
47 | 48 | self.audio_end_id = tokenizer.convert_tokens_to_ids(self.audio_end_tag) |
48 | 49 | self.get_image_patch_func = get_image_patch_func(kwargs["weight_dir"]) |
49 | 50 |
|
| 51 | + self.audio_min_length = MIN_AUDIO_LEN |
| 52 | + self.audio_max_length = 16000 * 30 |
| 53 | + |
50 | 54 | def init_imageitem_extral_params( |
51 | 55 | self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams |
52 | 56 | ): |
@@ -81,15 +85,35 @@ def get_image_token_length(self, img: ImageItem): |
81 | 85 |
|
def get_audio_token_length(self, audio: "AudioItem"):
    """Return the total number of audio tokens this audio item will occupy.

    The waveform (``audio.audio_length`` samples) is split into chunks of at
    most ``self.audio_max_length`` samples; any chunk shorter than
    ``self.audio_min_length`` is treated as padded up to that minimum.  For
    each chunk the mel / conv-stack downsampling of the audio encoder is
    simulated to get the per-chunk token count, and the counts are summed.

    Args:
        audio: item whose ``audio_length`` attribute is the raw sample count.

    Returns:
        int: total token count across all chunks (always >= 1).
    """
    total_samples = audio.audio_length

    # Split into encoder-sized chunks, clamping each chunk to the minimum
    # length.  A single loop covers both the short (<= max) and the long
    # (multi-chunk) case; the `while True` form also yields exactly one
    # minimum-length chunk when total_samples == 0, matching the original.
    chunk_lens = []
    start = 0
    while True:
        end = min(start + self.audio_max_length, total_samples)
        # NOTE(review): __init__ sets self.audio_min_length = MIN_AUDIO_LEN,
        # so using the attribute here is identical to the module constant.
        chunk_lens.append(max(end - start, self.audio_min_length))
        start = end
        if start >= total_samples:
            break

    # Conv stack of the audio encoder as (padding, kernel_size, stride) per
    # layer.  Was previously built via eval() on a hard-coded literal — a
    # plain tuple is equivalent, faster, and safe.
    conv_layers = ((1, 3, 1), (1, 3, 2))
    dilation = 1

    audio_token_num = 0
    for chunk_len in chunk_lens:
        # Hop size 160 samples per mel frame — presumably whisper-style
        # feature extraction (cf. the whisper defaults import); confirm
        # against the audio front-end if the hop ever changes.
        length = chunk_len // 160
        for padding, kernel_size, stride in conv_layers:
            # Standard conv output-length formula:
            # L_out = 1 + (L_in + 2*pad - dilation*(kernel-1) - 1) // stride
            length = 1 + (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride
        # Final downsampling halves the post-conv sequence to token count.
        audio_token_num += (length - 2) // 2 + 1
    return audio_token_num
94 | 118 |
|
95 | 119 | # only change the impl of the encode func: |
|
0 commit comments