@@ -328,13 +328,9 @@ def generate_for_qwen_audio(
328328 input_tokens ,
329329 args ,
330330 prompt_table = None ,
331- tasks = None ,
332- task_vocab_size = None ,
333331 extra_ids = None ,
334332 run_time = 1 ,
335333 ):
336- input_ids = None
337- input_lengths = None
338334 input_ids = torch .as_tensor (input_tokens ,
339335 device = self .gpu_device ,
340336 dtype = torch .int32 )
@@ -398,8 +394,7 @@ def qwen_infer(self,
398394 stream ,
399395 history = None ,
400396 past_audio_features = None ,
401- run_time = 1 ,
402- gpu_id = 0 ):
397+ run_time = 1 ):
403398 assert input_text , "input_text must be provided"
404399 assert torch .cuda .is_available (), "no gpu available"
405400 # preprocess on CPU maybe faster
@@ -464,9 +459,7 @@ def qwen_infer(self,
464459 # 1. Create a mask to know where special audio tokens are
465460 special_audio_token_mask = input_ids == self .config .audio_token_index
466461 special_audio_token_num = special_audio_token_mask .sum ().item ()
467- if past_audio_features is None :
468- assert special_audio_token_num == num_audios , f'special_audio_token_num { special_audio_token_num } should be equal to num_audios { num_audios } '
469- else :
462+ if past_audio_features is not None :
470463 assert isinstance (past_audio_features ,
471464 list ), f'past_audio_features should be a list'
472465 assert (
@@ -497,40 +490,16 @@ def qwen_infer(self,
497490 batch_indices , non_audio_indices = torch .where (
498491 input_ids != self .config .audio_token_index )
499492
500- # 2. Compute the positions where text should be written
501- # Calculate new positions for text tokens in merged audio-text sequence.
502- # `special_audio_token_mask` identifies audio tokens. Each audio token will be replaced by `audio_feat_lengths - 1` text tokens.
503- # `torch.cumsum` computes how each audio token shifts subsequent text token positions.
504- token_placeholder_num = torch .zeros_like (input_ids , device = device )
505- token_placeholder_num [
506- special_audio_token_mask ] = num_audio_tokens .long () - 1
507- token_placeholder_num = token_placeholder_num + 1
508- new_token_positions = torch .cumsum (token_placeholder_num , - 1 ) - 1
509- max_token_num = token_placeholder_num .sum (- 1 ).max ()
510- text_to_overwrite = new_token_positions [batch_indices ,
511- non_audio_indices ]
512-
513- # 3. Create the final_input_ids, already padded to the maximum position
514- final_input_ids = torch .full ((batch_size , max_token_num ),
515- self .config .audio_token_index ,
516- dtype = input_ids .dtype ,
517- device = device )
493+ # 2. Fill the final input ids based on the mask.
494+ batch_indices , audio_indices = torch .where (
495+ input_ids == self .config .audio_token_index )
518496
519- # 4. Fill the final_input_ids based on the mask. If we have ["hey" "<audio>", "how", "are"]
520- # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the audio features
521- final_input_ids [batch_indices ,
522- text_to_overwrite ] = input_ids [batch_indices ,
523- non_audio_indices ]
524497 vocab_size = self .config .vocab_size
525498 fake_prompt_id = torch .arange (vocab_size ,
526499 vocab_size + num_audio_tokens .sum (),
527500 device = device )
528- batch_indices , audio_indices = torch .where (
529- final_input_ids == self .config .audio_token_index )
530- final_input_ids [batch_indices , audio_indices ] = fake_prompt_id
531501
532- input_ids = final_input_ids .contiguous ().to (dtype = torch .int32 ,
533- device = self .gpu_device )
502+ input_ids [batch_indices , audio_indices ] = fake_prompt_id
534503 input_lengths = torch .tensor (input_ids .size (1 ),
535504 dtype = torch .int32 ,
536505 device = self .gpu_device )
@@ -568,8 +537,7 @@ def qwen_infer(self,
568537
569538 # print(f"extra_ids: {extra_ids}")
570539 output_ids , Qwen_time = self .generate_for_qwen_audio (
571- input_ids , args , prompt_table , tasks , task_vocab_size , extra_ids ,
572- run_time )
540+ input_ids , args , prompt_table , extra_ids , run_time )
573541
574542 runtime_rank = tensorrt_llm .mpi_rank ()
575543 input_lengths = torch .tensor ([input_ids .size (1 )],
0 commit comments