Skip to content

Commit efbe411

Browse files
[Feature]: Support disaggregated inference pipeline for Qwen3_TTS (vllm-project#1161)
Signed-off-by: Sy03 <1370724210@qq.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
1 parent 5295f0f commit efbe411

File tree

23 files changed

+2966
-2588
lines changed

23 files changed

+2966
-2588
lines changed

.buildkite/pipeline.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ steps:
176176
# type: DirectoryOrCreate
177177

178178
- label: "Omni Model Test"
179-
timeout_in_minutes: 15
179+
timeout_in_minutes: 20
180180
depends_on: image-build
181181
commands:
182182
- export VLLM_LOGGING_LEVEL=DEBUG

examples/offline_inference/qwen3_tts/end2end.py

Lines changed: 120 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,21 @@
44
tasks, then runs Omni generation and saves output wav files.
55
"""
66

7+
import logging
78
import os
8-
from typing import NamedTuple
9+
from typing import Any, NamedTuple
910

1011
import soundfile as sf
12+
import torch
1113

1214
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
1315

14-
from vllm import SamplingParams
1516
from vllm.utils.argparse_utils import FlexibleArgumentParser
1617

1718
from vllm_omni import Omni
1819

20+
logger = logging.getLogger(__name__)
21+
1922

2023
class QueryResult(NamedTuple):
2124
"""Container for a prepared Omni request."""
@@ -24,6 +27,44 @@ class QueryResult(NamedTuple):
2427
model_name: str
2528

2629

30+
def _estimate_prompt_len(
31+
additional_information: dict[str, Any],
32+
model_name: str,
33+
_cache: dict[str, Any] = {},
34+
) -> int:
35+
"""Estimate prompt_token_ids placeholder length for the Talker stage.
36+
37+
The AR Talker replaces all input embeddings via ``preprocess``, so the
38+
placeholder values are irrelevant but the **length** must match the
39+
embeddings that ``preprocess`` will produce.
40+
"""
41+
try:
42+
from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import Qwen3TTSConfig
43+
from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import (
44+
Qwen3TTSTalkerForConditionalGeneration,
45+
)
46+
47+
if model_name not in _cache:
48+
from transformers import AutoTokenizer
49+
50+
tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left")
51+
cfg = Qwen3TTSConfig.from_pretrained(model_name, trust_remote_code=True)
52+
_cache[model_name] = (tok, getattr(cfg, "talker_config", None))
53+
54+
tok, tcfg = _cache[model_name]
55+
task_type = (additional_information.get("task_type") or ["CustomVoice"])[0]
56+
return Qwen3TTSTalkerForConditionalGeneration.estimate_prompt_len_from_additional_information(
57+
additional_information=additional_information,
58+
task_type=task_type,
59+
tokenize_prompt=lambda t: tok(t, padding=False)["input_ids"],
60+
codec_language_id=getattr(tcfg, "codec_language_id", None),
61+
spk_is_dialect=getattr(tcfg, "spk_is_dialect", None),
62+
)
63+
except Exception as exc:
64+
logger.warning("Failed to estimate prompt length, using fallback 2048: %s", exc)
65+
return 2048
66+
67+
2768
def get_custom_voice_query(use_batch_sample: bool = False) -> QueryResult:
2869
"""Build CustomVoice sample inputs.
2970
@@ -34,47 +75,48 @@ def get_custom_voice_query(use_batch_sample: bool = False) -> QueryResult:
3475
QueryResult with Omni inputs and the CustomVoice model path.
3576
"""
3677
task_type = "CustomVoice"
78+
model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
3779
if use_batch_sample:
3880
texts = ["其实我真的有发现,我是一个特别善于观察别人情绪的人。", "She said she would be here by noon."]
3981
instructs = ["", "Very happy."]
4082
languages = ["Chinese", "English"]
4183
speakers = ["Vivian", "Ryan"]
4284
inputs = []
4385
for text, instruct, language, speaker in zip(texts, instructs, languages, speakers):
44-
prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
86+
additional_information = {
87+
"task_type": [task_type],
88+
"text": [text],
89+
"instruct": [instruct],
90+
"language": [language],
91+
"speaker": [speaker],
92+
"max_new_tokens": [2048],
93+
}
4594
inputs.append(
4695
{
47-
"prompt": prompt,
48-
"additional_information": {
49-
"task_type": [task_type],
50-
"text": [text],
51-
"instruct": [instruct],
52-
"language": [language],
53-
"speaker": [speaker],
54-
"max_new_tokens": [2048],
55-
},
96+
"prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name),
97+
"additional_information": additional_information,
5698
}
5799
)
58100
else:
59101
text = "其实我真的有发现,我是一个特别善于观察别人情绪的人。"
60102
language = "Chinese"
61103
speaker = "Vivian"
62104
instruct = "用特别愤怒的语气说"
63-
prompts = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
105+
additional_information = {
106+
"task_type": [task_type],
107+
"text": [text],
108+
"language": [language],
109+
"speaker": [speaker],
110+
"instruct": [instruct],
111+
"max_new_tokens": [2048],
112+
}
64113
inputs = {
65-
"prompt": prompts,
66-
"additional_information": {
67-
"task_type": [task_type],
68-
"text": [text],
69-
"language": [language],
70-
"speaker": [speaker],
71-
"instruct": [instruct],
72-
"max_new_tokens": [2048],
73-
},
114+
"prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name),
115+
"additional_information": additional_information,
74116
}
75117
return QueryResult(
76118
inputs=inputs,
77-
model_name="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
119+
model_name=model_name,
78120
)
79121

80122

@@ -88,6 +130,7 @@ def get_voice_design_query(use_batch_sample: bool = False) -> QueryResult:
88130
QueryResult with Omni inputs and the VoiceDesign model path.
89131
"""
90132
task_type = "VoiceDesign"
133+
model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign"
91134
if use_batch_sample:
92135
texts = [
93136
"哥哥,你回来啦,人家等了你好久好久了,要抱抱!",
@@ -100,39 +143,39 @@ def get_voice_design_query(use_batch_sample: bool = False) -> QueryResult:
100143
languages = ["Chinese", "English"]
101144
inputs = []
102145
for text, instruct, language in zip(texts, instructs, languages):
103-
prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
146+
additional_information = {
147+
"task_type": [task_type],
148+
"text": [text],
149+
"language": [language],
150+
"instruct": [instruct],
151+
"max_new_tokens": [2048],
152+
"non_streaming_mode": [True],
153+
}
104154
inputs.append(
105155
{
106-
"prompt": prompt,
107-
"additional_information": {
108-
"task_type": [task_type],
109-
"text": [text],
110-
"language": [language],
111-
"instruct": [instruct],
112-
"max_new_tokens": [2048],
113-
"non_streaming_mode": [True],
114-
},
156+
"prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name),
157+
"additional_information": additional_information,
115158
}
116159
)
117160
else:
118161
text = "哥哥,你回来啦,人家等了你好久好久了,要抱抱!"
119162
instruct = "体现撒娇稚嫩的萝莉女声,音调偏高且起伏明显,营造出黏人、做作又刻意卖萌的听觉效果。"
120163
language = "Chinese"
121-
prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
164+
additional_information = {
165+
"task_type": [task_type],
166+
"text": [text],
167+
"language": [language],
168+
"instruct": [instruct],
169+
"max_new_tokens": [2048],
170+
"non_streaming_mode": [True],
171+
}
122172
inputs = {
123-
"prompt": prompt,
124-
"additional_information": {
125-
"task_type": [task_type],
126-
"text": [text],
127-
"language": [language],
128-
"instruct": [instruct],
129-
"max_new_tokens": [2048],
130-
"non_streaming_mode": [True],
131-
},
173+
"prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name),
174+
"additional_information": additional_information,
132175
}
133176
return QueryResult(
134177
inputs=inputs,
135-
model_name="Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign",
178+
model_name=model_name,
136179
)
137180

138181

@@ -147,6 +190,7 @@ def get_base_query(use_batch_sample: bool = False, mode_tag: str = "icl") -> Que
147190
QueryResult with Omni inputs and the Base model path.
148191
"""
149192
task_type = "Base"
193+
model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
150194
ref_audio_path_1 = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-TTS-Repo/clone_2.wav"
151195
ref_audio_single = ref_audio_path_1
152196
ref_text_single = (
@@ -163,38 +207,38 @@ def get_base_query(use_batch_sample: bool = False, mode_tag: str = "icl") -> Que
163207
syn_lang_batch = ["Chinese", "English"]
164208
inputs = []
165209
for text, language in zip(syn_text_batch, syn_lang_batch):
166-
prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
210+
additional_information = {
211+
"task_type": [task_type],
212+
"ref_audio": [ref_audio_single],
213+
"ref_text": [ref_text_single],
214+
"text": [text],
215+
"language": [language],
216+
"x_vector_only_mode": [x_vector_only_mode],
217+
"max_new_tokens": [2048],
218+
}
167219
inputs.append(
168220
{
169-
"prompt": prompt,
170-
"additional_information": {
171-
"task_type": [task_type],
172-
"ref_audio": [ref_audio_single],
173-
"ref_text": [ref_text_single],
174-
"text": [text],
175-
"language": [language],
176-
"x_vector_only_mode": [x_vector_only_mode],
177-
"max_new_tokens": [2048],
178-
},
221+
"prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name),
222+
"additional_information": additional_information,
179223
}
180224
)
181225
else:
182-
prompt = f"<|im_start|>assistant\n{syn_text_single}<|im_end|>\n<|im_start|>assistant\n"
226+
additional_information = {
227+
"task_type": [task_type],
228+
"ref_audio": [ref_audio_single],
229+
"ref_text": [ref_text_single],
230+
"text": [syn_text_single],
231+
"language": [syn_lang_single],
232+
"x_vector_only_mode": [x_vector_only_mode],
233+
"max_new_tokens": [2048],
234+
}
183235
inputs = {
184-
"prompt": prompt,
185-
"additional_information": {
186-
"task_type": [task_type],
187-
"ref_audio": [ref_audio_single],
188-
"ref_text": [ref_text_single],
189-
"text": [syn_text_single],
190-
"language": [syn_lang_single],
191-
"x_vector_only_mode": [x_vector_only_mode],
192-
"max_new_tokens": [2048],
193-
},
236+
"prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name),
237+
"additional_information": additional_information,
194238
}
195239
return QueryResult(
196240
inputs=inputs,
197-
model_name="Qwen/Qwen3-TTS-12Hz-1.7B-Base",
241+
model_name=model_name,
198242
)
199243

200244

@@ -223,30 +267,22 @@ def main(args):
223267
stage_init_timeout=args.stage_init_timeout,
224268
)
225269

226-
sampling_params = SamplingParams(
227-
temperature=0.9,
228-
top_p=1.0,
229-
top_k=50,
230-
max_tokens=2048,
231-
seed=42,
232-
detokenize=False,
233-
repetition_penalty=1.05,
234-
)
235-
236-
sampling_params_list = [
237-
sampling_params,
238-
]
239-
240270
output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav
241271
os.makedirs(output_dir, exist_ok=True)
242272

243-
omni_generator = omni.generate(query_result.inputs, sampling_params_list)
273+
omni_generator = omni.generate(query_result.inputs, sampling_params_list=None)
244274
for stage_outputs in omni_generator:
245275
for output in stage_outputs.request_output:
246276
request_id = output.request_id
247-
audio_tensor = output.outputs[0].multimodal_output["audio"]
277+
audio_data = output.outputs[0].multimodal_output["audio"]
278+
# async_chunk mode returns a list of chunks; concatenate them.
279+
if isinstance(audio_data, list):
280+
audio_tensor = torch.cat(audio_data, dim=-1)
281+
else:
282+
audio_tensor = audio_data
248283
output_wav = os.path.join(output_dir, f"output_{request_id}.wav")
249-
audio_samplerate = output.outputs[0].multimodal_output["sr"].item()
284+
sr_val = output.outputs[0].multimodal_output["sr"]
285+
audio_samplerate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val[-1])
250286
# Convert to numpy array and ensure correct format
251287
audio_numpy = audio_tensor.float().detach().cpu().numpy()
252288

tests/entrypoints/openai_api/test_serving_speech.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -310,10 +310,12 @@ def test_is_tts_model(self, speech_server):
310310
speech_server.engine_client.stage_list = [mock_stage]
311311
assert speech_server._is_tts_model() is True
312312

313-
def test_build_tts_prompt(self, speech_server):
314-
"""Test TTS prompt format."""
315-
prompt = speech_server._build_tts_prompt("Hello")
316-
assert prompt == "<|im_start|>assistant\nHello<|im_end|>\n<|im_start|>assistant\n"
313+
def test_estimate_prompt_len_fallback(self, speech_server):
314+
"""Test prompt length estimation falls back to 2048 when model is unavailable."""
315+
tts_params = {"text": ["Hello"], "task_type": ["CustomVoice"]}
316+
result = speech_server._estimate_prompt_len(tts_params)
317+
# Without a real model, it should fall back to 2048.
318+
assert result == 2048
317319

318320
def test_validate_tts_request_basic(self, speech_server):
319321
"""Test basic validation cases."""

tests/worker/test_omni_gpu_model_runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def _make_runner(req_ids=("r1", "r2"), hidden_size=4):
6969
runner.text_step = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32))
7070

7171
runner.talker_mtp = DummyTalkerMTP()
72+
runner.model = SimpleNamespace(talker_mtp_output_key="code_predictor_codes")
7273
runner.vllm_config = object()
7374

7475
# Provide a minimal implementation that returns the expected 4-tuple.

vllm_omni/config/model.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class OmniModelConfig(ModelConfig):
5858
}
5959
)
6060
omni_kv_config: dict | None = None
61+
codec_frame_rate_hz: float | None = None
6162

6263
@property
6364
def registry(self):
@@ -128,6 +129,21 @@ def __post_init__(
128129
video_pruning_rate=video_pruning_rate,
129130
)
130131

132+
# Qwen3-TTS: infer codec frame rate from the model config for online serving.
133+
if self.codec_frame_rate_hz is None and self.model_arch == "Qwen3TTSTalkerForConditionalGenerationARVLLM":
134+
talker_cfg = getattr(self.hf_config, "talker_config", None)
135+
if isinstance(talker_cfg, dict):
136+
pos_per_sec = talker_cfg.get("position_id_per_seconds")
137+
else:
138+
pos_per_sec = getattr(talker_cfg, "position_id_per_seconds", None)
139+
if pos_per_sec is not None:
140+
try:
141+
fps = float(pos_per_sec)
142+
except Exception:
143+
fps = None
144+
if fps is not None and fps > 0:
145+
self.codec_frame_rate_hz = fps
146+
131147
# Override hf_text_config with omni-specific logic for multi-stage models
132148
# (e.g., thinker_config, talker_config)
133149
new_hf_text_config = self.draw_hf_text_config()

0 commit comments

Comments (0)