tasks, then runs Omni generation and saves output wav files.
"""
66
7+ import logging
78import os
8- from typing import NamedTuple
9+ from typing import Any , NamedTuple
910
1011import soundfile as sf
12+ import torch
1113
1214os .environ ["VLLM_WORKER_MULTIPROC_METHOD" ] = "spawn"
1315
14- from vllm import SamplingParams
1516from vllm .utils .argparse_utils import FlexibleArgumentParser
1617
1718from vllm_omni import Omni
1819
20+ logger = logging .getLogger (__name__ )
21+
1922
2023class QueryResult (NamedTuple ):
2124 """Container for a prepared Omni request."""
@@ -24,6 +27,44 @@ class QueryResult(NamedTuple):
2427 model_name : str
2528
2629
# Shared cache of (tokenizer, talker_config) pairs keyed by model name.
# Kept at module level instead of as a mutable default argument so the
# cross-call sharing is explicit rather than a hidden Python gotcha.
_PROMPT_LEN_CACHE: dict[str, Any] = {}


def _estimate_prompt_len(
    additional_information: dict[str, Any],
    model_name: str,
    _cache: dict[str, Any] | None = None,
) -> int:
    """Estimate the ``prompt_token_ids`` placeholder length for the Talker stage.

    The AR Talker replaces all input embeddings via ``preprocess``, so the
    placeholder token values are irrelevant but the **length** must match the
    embeddings that ``preprocess`` will produce.

    Args:
        additional_information: Per-request metadata (task type, text, speaker,
            ...) whose values are single-element lists, following the Omni
            input convention used by the query builders in this script.
        model_name: HF model id or local path; used to load the tokenizer and
            TTS config once, after which they are cached.
        _cache: Optional explicit cache mapping model name to
            ``(tokenizer, talker_config)``. Defaults to the shared
            module-level cache.

    Returns:
        The estimated prompt length in tokens, or a conservative fallback of
        2048 if anything goes wrong (missing optional deps, download failure).
    """
    if _cache is None:
        _cache = _PROMPT_LEN_CACHE
    try:
        from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import Qwen3TTSConfig
        from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import (
            Qwen3TTSTalkerForConditionalGeneration,
        )

        if model_name not in _cache:
            from transformers import AutoTokenizer

            tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left")
            cfg = Qwen3TTSConfig.from_pretrained(model_name, trust_remote_code=True)
            _cache[model_name] = (tok, getattr(cfg, "talker_config", None))

        tok, tcfg = _cache[model_name]
        # task_type arrives as a single-element list; default mirrors the
        # CustomVoice sample task.
        task_type = (additional_information.get("task_type") or ["CustomVoice"])[0]
        return Qwen3TTSTalkerForConditionalGeneration.estimate_prompt_len_from_additional_information(
            additional_information=additional_information,
            task_type=task_type,
            tokenize_prompt=lambda t: tok(t, padding=False)["input_ids"],
            codec_language_id=getattr(tcfg, "codec_language_id", None),
            spk_is_dialect=getattr(tcfg, "spk_is_dialect", None),
        )
    except Exception as exc:
        # Best-effort estimation: log and fall back rather than fail the demo.
        logger.warning("Failed to estimate prompt length, using fallback 2048: %s", exc)
        return 2048
67+
def get_custom_voice_query(use_batch_sample: bool = False) -> QueryResult:
    """Build CustomVoice sample inputs.

    Args:
        use_batch_sample: If True, build a two-request batch; otherwise build
            a single request.

    Returns:
        QueryResult with Omni inputs and the CustomVoice model path.
    """
    task_type = "CustomVoice"
    model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"

    def make_input(text: str, instruct: str, language: str, speaker: str) -> dict[str, Any]:
        # Single-element list values follow the Omni "additional_information"
        # convention shared by the other query builders in this script.
        additional_information = {
            "task_type": [task_type],
            "text": [text],
            "instruct": [instruct],
            "language": [language],
            "speaker": [speaker],
            "max_new_tokens": [2048],
        }
        # Placeholder token ids: only the *length* matters, the Talker stage
        # replaces the embeddings (see _estimate_prompt_len).
        return {
            "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name),
            "additional_information": additional_information,
        }

    if use_batch_sample:
        texts = ["其实我真的有发现,我是一个特别善于观察别人情绪的人。", "She said she would be here by noon."]
        instructs = ["", "Very happy."]
        languages = ["Chinese", "English"]
        speakers = ["Vivian", "Ryan"]
        inputs = [
            make_input(text, instruct, language, speaker)
            for text, instruct, language, speaker in zip(texts, instructs, languages, speakers)
        ]
    else:
        inputs = make_input(
            text="其实我真的有发现,我是一个特别善于观察别人情绪的人。",
            instruct="用特别愤怒的语气说",
            language="Chinese",
            speaker="Vivian",
        )
    return QueryResult(
        inputs=inputs,
        model_name=model_name,
    )
79121
80122
@@ -88,6 +130,7 @@ def get_voice_design_query(use_batch_sample: bool = False) -> QueryResult:
88130 QueryResult with Omni inputs and the VoiceDesign model path.
89131 """
90132 task_type = "VoiceDesign"
133+ model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign"
91134 if use_batch_sample :
92135 texts = [
93136 "哥哥,你回来啦,人家等了你好久好久了,要抱抱!" ,
@@ -100,39 +143,39 @@ def get_voice_design_query(use_batch_sample: bool = False) -> QueryResult:
100143 languages = ["Chinese" , "English" ]
101144 inputs = []
102145 for text , instruct , language in zip (texts , instructs , languages ):
103- prompt = f"<|im_start|>assistant\n { text } <|im_end|>\n <|im_start|>assistant\n "
146+ additional_information = {
147+ "task_type" : [task_type ],
148+ "text" : [text ],
149+ "language" : [language ],
150+ "instruct" : [instruct ],
151+ "max_new_tokens" : [2048 ],
152+ "non_streaming_mode" : [True ],
153+ }
104154 inputs .append (
105155 {
106- "prompt" : prompt ,
107- "additional_information" : {
108- "task_type" : [task_type ],
109- "text" : [text ],
110- "language" : [language ],
111- "instruct" : [instruct ],
112- "max_new_tokens" : [2048 ],
113- "non_streaming_mode" : [True ],
114- },
156+ "prompt_token_ids" : [0 ] * _estimate_prompt_len (additional_information , model_name ),
157+ "additional_information" : additional_information ,
115158 }
116159 )
117160 else :
118161 text = "哥哥,你回来啦,人家等了你好久好久了,要抱抱!"
119162 instruct = "体现撒娇稚嫩的萝莉女声,音调偏高且起伏明显,营造出黏人、做作又刻意卖萌的听觉效果。"
120163 language = "Chinese"
121- prompt = f"<|im_start|>assistant\n { text } <|im_end|>\n <|im_start|>assistant\n "
164+ additional_information = {
165+ "task_type" : [task_type ],
166+ "text" : [text ],
167+ "language" : [language ],
168+ "instruct" : [instruct ],
169+ "max_new_tokens" : [2048 ],
170+ "non_streaming_mode" : [True ],
171+ }
122172 inputs = {
123- "prompt" : prompt ,
124- "additional_information" : {
125- "task_type" : [task_type ],
126- "text" : [text ],
127- "language" : [language ],
128- "instruct" : [instruct ],
129- "max_new_tokens" : [2048 ],
130- "non_streaming_mode" : [True ],
131- },
173+ "prompt_token_ids" : [0 ] * _estimate_prompt_len (additional_information , model_name ),
174+ "additional_information" : additional_information ,
132175 }
133176 return QueryResult (
134177 inputs = inputs ,
135- model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign" ,
178+ model_name = model_name ,
136179 )
137180
138181
@@ -147,6 +190,7 @@ def get_base_query(use_batch_sample: bool = False, mode_tag: str = "icl") -> Que
147190 QueryResult with Omni inputs and the Base model path.
148191 """
149192 task_type = "Base"
193+ model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
150194 ref_audio_path_1 = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-TTS-Repo/clone_2.wav"
151195 ref_audio_single = ref_audio_path_1
152196 ref_text_single = (
@@ -163,38 +207,38 @@ def get_base_query(use_batch_sample: bool = False, mode_tag: str = "icl") -> Que
163207 syn_lang_batch = ["Chinese" , "English" ]
164208 inputs = []
165209 for text , language in zip (syn_text_batch , syn_lang_batch ):
166- prompt = f"<|im_start|>assistant\n { text } <|im_end|>\n <|im_start|>assistant\n "
210+ additional_information = {
211+ "task_type" : [task_type ],
212+ "ref_audio" : [ref_audio_single ],
213+ "ref_text" : [ref_text_single ],
214+ "text" : [text ],
215+ "language" : [language ],
216+ "x_vector_only_mode" : [x_vector_only_mode ],
217+ "max_new_tokens" : [2048 ],
218+ }
167219 inputs .append (
168220 {
169- "prompt" : prompt ,
170- "additional_information" : {
171- "task_type" : [task_type ],
172- "ref_audio" : [ref_audio_single ],
173- "ref_text" : [ref_text_single ],
174- "text" : [text ],
175- "language" : [language ],
176- "x_vector_only_mode" : [x_vector_only_mode ],
177- "max_new_tokens" : [2048 ],
178- },
221+ "prompt_token_ids" : [0 ] * _estimate_prompt_len (additional_information , model_name ),
222+ "additional_information" : additional_information ,
179223 }
180224 )
181225 else :
182- prompt = f"<|im_start|>assistant\n { syn_text_single } <|im_end|>\n <|im_start|>assistant\n "
226+ additional_information = {
227+ "task_type" : [task_type ],
228+ "ref_audio" : [ref_audio_single ],
229+ "ref_text" : [ref_text_single ],
230+ "text" : [syn_text_single ],
231+ "language" : [syn_lang_single ],
232+ "x_vector_only_mode" : [x_vector_only_mode ],
233+ "max_new_tokens" : [2048 ],
234+ }
183235 inputs = {
184- "prompt" : prompt ,
185- "additional_information" : {
186- "task_type" : [task_type ],
187- "ref_audio" : [ref_audio_single ],
188- "ref_text" : [ref_text_single ],
189- "text" : [syn_text_single ],
190- "language" : [syn_lang_single ],
191- "x_vector_only_mode" : [x_vector_only_mode ],
192- "max_new_tokens" : [2048 ],
193- },
236+ "prompt_token_ids" : [0 ] * _estimate_prompt_len (additional_information , model_name ),
237+ "additional_information" : additional_information ,
194238 }
195239 return QueryResult (
196240 inputs = inputs ,
197- model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-Base" ,
241+ model_name = model_name ,
198242 )
199243
200244
@@ -223,30 +267,22 @@ def main(args):
223267 stage_init_timeout = args .stage_init_timeout ,
224268 )
225269
226- sampling_params = SamplingParams (
227- temperature = 0.9 ,
228- top_p = 1.0 ,
229- top_k = 50 ,
230- max_tokens = 2048 ,
231- seed = 42 ,
232- detokenize = False ,
233- repetition_penalty = 1.05 ,
234- )
235-
236- sampling_params_list = [
237- sampling_params ,
238- ]
239-
240270 output_dir = args .output_dir if getattr (args , "output_dir" , None ) else args .output_wav
241271 os .makedirs (output_dir , exist_ok = True )
242272
243- omni_generator = omni .generate (query_result .inputs , sampling_params_list )
273+ omni_generator = omni .generate (query_result .inputs , sampling_params_list = None )
244274 for stage_outputs in omni_generator :
245275 for output in stage_outputs .request_output :
246276 request_id = output .request_id
247- audio_tensor = output .outputs [0 ].multimodal_output ["audio" ]
277+ audio_data = output .outputs [0 ].multimodal_output ["audio" ]
278+ # async_chunk mode returns a list of chunks; concatenate them.
279+ if isinstance (audio_data , list ):
280+ audio_tensor = torch .cat (audio_data , dim = - 1 )
281+ else :
282+ audio_tensor = audio_data
248283 output_wav = os .path .join (output_dir , f"output_{ request_id } .wav" )
249- audio_samplerate = output .outputs [0 ].multimodal_output ["sr" ].item ()
284+ sr_val = output .outputs [0 ].multimodal_output ["sr" ]
285+ audio_samplerate = sr_val .item () if hasattr (sr_val , "item" ) else int (sr_val [- 1 ])
250286 # Convert to numpy array and ensure correct format
251287 audio_numpy = audio_tensor .float ().detach ().cpu ().numpy ()
252288
0 commit comments