
Commit 0ec7b57

achartier and QiJune authored
chore: Handle qwen2audio inputs ids expansion during processing (NVIDIA#3080)
* Handle qwen2audio inputs ids expansion during processing

  Signed-off-by: Aurelien Chartier <[email protected]>

* remove more dead code

  Signed-off-by: Aurelien Chartier <[email protected]>

* fix yapf

  Signed-off-by: Aurelien Chartier <[email protected]>

---------

Signed-off-by: Aurelien Chartier <[email protected]>
Co-authored-by: QI JUN <[email protected]>
1 parent 3c7cb66 commit 0ec7b57

File tree

2 files changed: +7 -42 lines changed

examples/qwen2audio/run.py

Lines changed: 7 additions & 39 deletions
@@ -328,13 +328,9 @@ def generate_for_qwen_audio(
                                 input_tokens,
                                 args,
                                 prompt_table=None,
-                                tasks=None,
-                                task_vocab_size=None,
                                 extra_ids=None,
                                 run_time=1,
                                 ):
-        input_ids = None
-        input_lengths = None
         input_ids = torch.as_tensor(input_tokens,
                                     device=self.gpu_device,
                                     dtype=torch.int32)
@@ -398,8 +394,7 @@ def qwen_infer(self,
                    stream,
                    history=None,
                    past_audio_features=None,
-                   run_time=1,
-                   gpu_id=0):
+                   run_time=1):
         assert input_text, "input_text must be provided"
         assert torch.cuda.is_available(), "no gpu available"
         # preprocess on CPU maybe faster
@@ -464,9 +459,7 @@ def qwen_infer(self,
         # 1. Create a mask to know where special audio tokens are
         special_audio_token_mask = input_ids == self.config.audio_token_index
         special_audio_token_num = special_audio_token_mask.sum().item()
-        if past_audio_features is None:
-            assert special_audio_token_num == num_audios, f'special_audio_token_num {special_audio_token_num} should be equal to num_audios {num_audios}'
-        else:
+        if past_audio_features is not None:
             assert isinstance(past_audio_features,
                               list), f'past_audio_features should be a list'
             assert (
@@ -497,40 +490,16 @@ def qwen_infer(self,
         batch_indices, non_audio_indices = torch.where(
             input_ids != self.config.audio_token_index)

-        # 2. Compute the positions where text should be written
-        # Calculate new positions for text tokens in merged audio-text sequence.
-        # `special_audio_token_mask` identifies audio tokens. Each audio token will be replaced by `audio_feat_lengths - 1` text tokens.
-        # `torch.cumsum` computes how each audio token shifts subsequent text token positions.
-        token_placeholder_num = torch.zeros_like(input_ids, device=device)
-        token_placeholder_num[
-            special_audio_token_mask] = num_audio_tokens.long() - 1
-        token_placeholder_num = token_placeholder_num + 1
-        new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
-        max_token_num = token_placeholder_num.sum(-1).max()
-        text_to_overwrite = new_token_positions[batch_indices,
-                                                non_audio_indices]
-
-        # 3. Create the final_input_ids, already padded to the maximum position
-        final_input_ids = torch.full((batch_size, max_token_num),
-                                     self.config.audio_token_index,
-                                     dtype=input_ids.dtype,
-                                     device=device)
+        # 2. Fill the final input ids based on the mask.
+        batch_indices, audio_indices = torch.where(
+            input_ids == self.config.audio_token_index)

-        # 4. Fill the final_input_ids based on the mask. If we have ["hey" "<audio>", "how", "are"]
-        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the audio features
-        final_input_ids[batch_indices,
-                        text_to_overwrite] = input_ids[batch_indices,
-                                                       non_audio_indices]
         vocab_size = self.config.vocab_size
         fake_prompt_id = torch.arange(vocab_size,
                                       vocab_size + num_audio_tokens.sum(),
                                       device=device)
-        batch_indices, audio_indices = torch.where(
-            final_input_ids == self.config.audio_token_index)
-        final_input_ids[batch_indices, audio_indices] = fake_prompt_id

-        input_ids = final_input_ids.contiguous().to(dtype=torch.int32,
-                                                    device=self.gpu_device)
+        input_ids[batch_indices, audio_indices] = fake_prompt_id
         input_lengths = torch.tensor(input_ids.size(1),
                                      dtype=torch.int32,
                                      device=self.gpu_device)
@@ -568,8 +537,7 @@ def qwen_infer(self,

         # print(f"extra_ids: {extra_ids}")
         output_ids, Qwen_time = self.generate_for_qwen_audio(
-            input_ids, args, prompt_table, tasks, task_vocab_size, extra_ids,
-            run_time)
+            input_ids, args, prompt_table, extra_ids, run_time)

         runtime_rank = tensorrt_llm.mpi_rank()
         input_lengths = torch.tensor([input_ids.size(1),

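Taken together, the remaining input-id handling in run.py amounts to the short sketch below. It assumes the prompt ids reaching qwen_infer already contain one audio placeholder per audio frame (the expansion now happening during processing, upstream of this script), so the only step left is to overwrite those positions with ids past the text vocabulary, which the TensorRT-LLM runtime resolves through the prompt table of audio features. The audio_token_index and vocab_size values are invented for the example, and deriving num_audio_tokens from the mask is a simplification; the real script takes it from the audio feature lengths.

    # Illustrative, standalone sketch of the new substitution step (assumed values).
    import torch

    audio_token_index = 151646  # assumed placeholder id, not the real model config
    vocab_size = 151936         # assumed text vocab size, not the real model config

    # input_ids as produced upstream: the audio placeholder already occupies one
    # position per audio frame, so no re-expansion is needed here.
    input_ids = torch.tensor(
        [[100, 101, audio_token_index, audio_token_index, audio_token_index, 102]],
        dtype=torch.int32)

    # Locate the placeholder positions and overwrite them in place with "fake"
    # prompt ids that start right after the text vocabulary; the runtime maps
    # those ids onto rows of the prompt table holding the audio features.
    batch_indices, audio_indices = torch.where(input_ids == audio_token_index)
    num_audio_tokens = audio_indices.numel()
    fake_prompt_id = torch.arange(vocab_size, vocab_size + num_audio_tokens)
    input_ids[batch_indices, audio_indices] = fake_prompt_id.to(input_ids.dtype)

    print(input_ids)  # tensor([[100, 101, 151936, 151937, 151938, 102]], dtype=torch.int32)

The cumsum/final_input_ids code removed above performed that expansion manually inside run.py; with the expansion handled during processing, the in-place fill is all that remains.
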
tests/integration/defs/examples/test_qwen2audio.py

Lines changed: 0 additions & 3 deletions
@@ -41,9 +41,6 @@ def test_llm_qwen2audio_single_gpu(qwen2audio_example_root, llm_qwen_model_root,
     "Build & run qwen2audio on 1 gpu."
     workspace = llm_venv.get_working_directory()

-    # https://nvbugs/5136784
-    llm_venv.run_cmd(['-m', 'pip', 'install', 'transformers==4.47.1'])
-
     print("Generate audio engine...")
     audio_engine_dir = f"{engine_dir}/audio"
     audio_cmd = [
