
Commit 4dcb592

support-whisper-longaudio

Author: sangchengmeng
1 parent fbd13bb · commit 4dcb592

File tree

2 files changed: +43 additions, −9 deletions


lightllm/models/internvl/model.py

Lines changed: 0 additions & 1 deletion
@@ -81,7 +81,6 @@ def get_image_token_length(self, img: ImageItem):
 
     def get_audio_token_length(self, audio: AudioItem):
         L = audio.audio_length
-        L = L if L <= 480000 else 480000  # max_length < 30s
         mel_len = L // 160
         dilation = 1
         L_in = mel_len
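
The deleted line clamped every audio to 480000 samples (30 s at 16 kHz) before the token length was computed, so get_audio_token_length under-reported the count for longer clips. A minimal sketch of the effect, assuming 16 kHz input as in the whisper_audio.py call below; the 60 s clip is purely illustrative:

# Sketch only: how the mel-frame count changes once the 30 s clamp is removed.
# The 16 kHz sampling rate is an assumption taken from the audio_processor call below.
SAMPLE_RATE = 16000
audio_length = 60 * SAMPLE_RATE                  # hypothetical 60 s clip -> 960000 samples

old_mel_len = min(audio_length, 480000) // 160   # old behaviour: clamped to 30 s -> 3000 frames
new_mel_len = audio_length // 160                # new behaviour: full clip -> 6000 frames
print(old_mel_len, new_mel_len)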

lightllm/models/whisper/whisper_audio.py

Lines changed: 43 additions & 8 deletions
@@ -162,9 +162,12 @@ def forward(self, audio_values, audio_lens_after_cnn):
         return x
 
     def encode(self, audio_items: List[AudioItem]):
+        # each element is one chunk
         batch_audios = []
-        batch_audio_lens = np.zeros(len(audio_items), dtype=np.int32)
+        batch_audio_lens = []
         uuids = []
+        # record which audio_items index each chunk belongs to
+        chunk_owner_index = []
         for i, item in enumerate(audio_items):
             if isinstance(item, AudioItem):
                 uuids.append(item.uuid)
@@ -180,8 +183,25 @@ def encode(self, audio_items: List[AudioItem]):
             if audio.shape[0] < MIN_AUDIO_LEN:
                 audio = np.pad(audio, (0, MIN_AUDIO_LEN - len(audio)), mode="constant", constant_values=0.0)
 
-            batch_audio_lens[i] = min(audio.shape[0], self.max_length)
-            batch_audios.append(audio)
+            if audio.shape[0] > self.max_length:
+                start = 0
+                while start < audio.shape[0]:
+                    end = min(start + self.max_length, audio.shape[0])
+                    chunk = audio[start:end]
+
+                    if chunk.shape[0] < MIN_AUDIO_LEN:
+                        chunk = np.pad(chunk, (0, MIN_AUDIO_LEN - chunk.shape[0]), mode="constant", constant_values=0.0)
+                    batch_audios.append(chunk)
+                    batch_audio_lens.append(min(chunk.shape[0], self.max_length))
+                    chunk_owner_index.append(i)
+
+                    start = end
+            else:
+                batch_audio_lens.append(min(audio.shape[0], self.max_length))
+                batch_audios.append(audio)
+                chunk_owner_index.append(i)
+
+        batch_audio_lens = np.array(batch_audio_lens, dtype=np.int32)
 
         audios, audio_lens_after_cnn = self.audio_processor(
             batch_audios, batch_audio_lens, sampling_rate=16000, return_tensors="pt"
@@ -190,13 +210,28 @@ def encode(self, audio_items: List[AudioItem]):
         audio_lens_after_cnn = np.array(audio_lens_after_cnn, dtype=np.int32)
         audio_token_num = (audio_lens_after_cnn - 2) // 2 + 1
 
+        num_audios = len(audio_items)
+        per_audio_embeds = [[] for _ in range(num_audios)]
+
+        for chunk_idx, owner in enumerate(chunk_owner_index):
+            token_len = int(audio_token_num[chunk_idx])
+            if token_len <= 0:
+                continue
+            per_audio_embeds[owner].append(audios[chunk_idx][:token_len])
+
         ready_audio = obtain(self.cache_client.root.get_items_embed(uuids))
         ids_to_set = []
         for i, ready in enumerate(ready_audio):
-            if not ready:
-                uid = uuids[i]
-                cur_embed_bytes = tensor2bytes(audios[i][: audio_token_num[i]])
-                create_shm(get_shm_name_embed(uid), cur_embed_bytes)
-                ids_to_set.append(uid)
+            if ready:
+                continue
+
+            uid = uuids[i]
+
+            # concatenate all chunk embeddings for this audio
+            cur_embed = torch.cat(per_audio_embeds[i], dim=0)
+            cur_embed_bytes = tensor2bytes(cur_embed)
+            create_shm(get_shm_name_embed(uid), cur_embed_bytes)
+            ids_to_set.append(uid)
+
         if ids_to_set:
             self.cache_client.root.set_items_embed(ids=ids_to_set)
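
The new loop splits any waveform longer than self.max_length into consecutive max_length-sized pieces, pads a too-short tail up to MIN_AUDIO_LEN, and records in chunk_owner_index which source audio each chunk belongs to; the final hunk then stitches the per-chunk embeddings back together with torch.cat before writing one embedding per uuid to shared memory. A self-contained sketch of the splitting step, outside the class; the 480000-sample max_length (30 s at 16 kHz) and the MIN_AUDIO_LEN value are assumptions for illustration:

import numpy as np

MAX_LENGTH = 480000      # assumption: 30 s at 16 kHz, standing in for self.max_length
MIN_AUDIO_LEN = 160      # assumption: placeholder for the module's MIN_AUDIO_LEN constant

def split_into_chunks(audio, max_length=MAX_LENGTH):
    # Walk the waveform in max_length steps, padding a too-short tail chunk.
    chunks = []
    start = 0
    while start < audio.shape[0]:
        end = min(start + max_length, audio.shape[0])
        chunk = audio[start:end]
        if chunk.shape[0] < MIN_AUDIO_LEN:
            chunk = np.pad(chunk, (0, MIN_AUDIO_LEN - chunk.shape[0]),
                           mode="constant", constant_values=0.0)
        chunks.append(chunk)
        start = end
    return chunks

# A hypothetical 70 s clip at 16 kHz becomes three chunks: 30 s, 30 s and 10 s.
chunks = split_into_chunks(np.zeros(70 * 16000, dtype=np.float32))
print([c.shape[0] for c in chunks])    # [480000, 480000, 160000]

Each chunk is then encoded independently by the audio processor; because chunk_owner_index remembers the owning audio, per_audio_embeds[i] collects that audio's chunk embeddings in order and torch.cat(per_audio_embeds[i], dim=0) restores a single contiguous embedding per audio item.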
