
Commit 4de077e

JuanPZuluaga and pablo authored
[Qwen3TTS][Feat] Code2Wav batched decoding (vllm-project#1426)
Signed-off-by: pablo <pablo@agigo.ai>
Co-authored-by: pablo <pablo@agigo.ai>
1 parent 82e1bf2 commit 4de077e

File tree

5 files changed: +326 −99 lines

examples/offline_inference/qwen3_tts/README.md

Lines changed: 14 additions & 1 deletion

```diff
@@ -87,7 +87,20 @@ Examples:
 python end2end.py --query-type Base --mode-tag icl
 ```
 
+## Batched Decoding
+
+The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, provide a stage config with `max_batch_size > 1` and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`.
+
+```
+python end2end.py --query-type CustomVoice \
+    --txt-prompts benchmark_prompts.txt \
+    --batch-size 4 \
+    --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml
+```
+
+**Important:** `--batch-size` must match a CUDA graph capture size (1, 2, 4, 8, 16, ...) because the Talker's code predictor KV cache is sized to `max_num_seqs`, and CUDA graphs pad the batch to the next capture size. Both stages need `max_batch_size >= batch_size` in the stage config for batching to take effect. If only stage 1 has a higher `max_batch_size`, it won't help: stage 1 can only batch chunks from requests that are in-flight simultaneously, which requires stage 0 to also process multiple requests concurrently.
+
 ## Notes
 
 - The script uses the model paths embedded in `end2end.py`. Update them if your local cache path differs.
-- Use `--output-dir` (preferred) or `--output-wav` to change the output folder.
+- Use `--output-dir` to change the output folder.
```
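The capture-size rule from the **Important** note above can be sketched in a few lines. This is an illustration only: the capture sizes are the powers of two the note lists, and the helper names (`padded_batch_size`, `wasted_slots`) are hypothetical, not part of the diff.

```python
# Illustrative sketch: CUDA graphs pad each batch up to the next capture
# size, so a --batch-size that matches a capture size wastes no padded slots.
CAPTURE_SIZES = [1, 2, 4, 8, 16]  # assumed capture sizes from the README note


def padded_batch_size(batch_size: int) -> int:
    """Return the capture size the batch would be padded to."""
    for size in CAPTURE_SIZES:
        if batch_size <= size:
            return size
    raise ValueError(f"batch size {batch_size} exceeds the largest capture size")


def wasted_slots(batch_size: int) -> int:
    """Padded slots that carry no real request."""
    return padded_batch_size(batch_size) - batch_size


print(wasted_slots(4))  # 0: batch of 4 matches a capture size exactly
print(wasted_slots(3))  # 1: batch of 3 is padded up to capture size 4
```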
Lines changed: 12 additions & 0 deletions

```diff
@@ -0,0 +1,12 @@
+Hello, welcome to the voice synthesis benchmark test.
+She said she would be here by noon, but nobody showed up.
+The quick brown fox jumps over the lazy dog near the riverbank.
+I can't believe how beautiful the sunset looks from up here on the mountain.
+Please remember to bring your identification documents to the appointment tomorrow morning.
+Have you ever wondered what it would be like to travel through time and visit ancient civilizations?
+The restaurant on the corner serves the best pasta I have ever tasted in my entire life.
+After the meeting, we should discuss the quarterly results and plan for the next phase.
+Learning a new language takes patience, practice, and a genuine curiosity about other cultures.
+The train leaves at half past seven, so we need to arrive at the station before then.
+Could you please turn down the music a little bit, I'm trying to concentrate on my work.
+It was a dark and stormy night when the old lighthouse keeper heard a knock at the door.
```
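A prompts file like the one above is consumed line-by-line and split into fixed-size batches by the example script. The sketch below mirrors that flow under stated assumptions: the function names are illustrative, and the stand-in prompt list replaces an actual file read.

```python
# Sketch of the --txt-prompts / --batch-size flow: read non-empty lines,
# then split them into batches of at most batch_size prompts each.
def load_prompts(path: str) -> list[str]:
    """Read one prompt per line, skipping blank lines (as the diff does)."""
    with open(path) as f:
        return [line.strip() for line in f if line.strip()]


def chunk(prompts: list[str], batch_size: int) -> list[list[str]]:
    """Split prompts into consecutive batches of at most batch_size."""
    return [prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)]


# Stand-in for the 12 benchmark lines above (avoids a real file read here).
prompts = [f"prompt {i}" for i in range(12)]
batches = chunk(prompts, 4)
print(len(batches))  # 3: twelve prompts with --batch-size 4 yield three batches
```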
examples/offline_inference/qwen3_tts/end2end.py

Lines changed: 69 additions & 26 deletions

```diff
@@ -248,6 +248,13 @@ def main(args):
     Args:
         args: Parsed CLI args from parse_args().
     """
+    if args.batch_size < 1 or (args.batch_size & (args.batch_size - 1)) != 0:
+        raise ValueError(
+            f"--batch-size must be a power of two (got {args.batch_size}); "
+            "non-power-of-two values do not align with CUDA graph capture sizes "
+            "of Code2Wav."
+        )
+
     query_func = query_map[args.query_type]
     if args.query_type in {"CustomVoice", "VoiceDesign"}:
         query_result = query_func(use_batch_sample=args.use_batch_sample)
@@ -260,39 +267,69 @@ def main(args):
         query_result = query_func()
 
     model_name = query_result.model_name
+
+    # Load prompts from text file if provided.
+    # Use the default query as a template so task-specific fields
+    # (e.g. ref_audio for Base) are preserved; only override text.
+    if args.txt_prompts:
+        with open(args.txt_prompts) as f:
+            lines = [line.strip() for line in f if line.strip()]
+        if not lines:
+            raise ValueError(f"No valid prompts found in {args.txt_prompts}")
+        template = query_result.inputs
+        if isinstance(template, list):
+            template = template[0]
+        template_info = template["additional_information"]
+        inputs = []
+        for text in lines:
+            additional_information = {**template_info, "text": [text]}
+            inputs.append(
+                {
+                    "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name),
+                    "additional_information": additional_information,
+                }
+            )
+    else:
+        inputs = query_result.inputs
+        if not isinstance(inputs, list):
+            inputs = [inputs]
 
     omni = Omni(
         model=model_name,
         stage_configs_path=args.stage_configs_path,
         log_stats=args.log_stats,
         stage_init_timeout=args.stage_init_timeout,
     )
 
-    output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav
+    output_dir = args.output_dir
     os.makedirs(output_dir, exist_ok=True)
 
-    omni_generator = omni.generate(query_result.inputs, sampling_params_list=None)
-    for stage_outputs in omni_generator:
-        for output in stage_outputs.request_output:
-            request_id = output.request_id
-            audio_data = output.outputs[0].multimodal_output["audio"]
-            # async_chunk mode returns a list of chunks; concatenate them.
-            if isinstance(audio_data, list):
-                audio_tensor = torch.cat(audio_data, dim=-1)
-            else:
-                audio_tensor = audio_data
-            output_wav = os.path.join(output_dir, f"output_{request_id}.wav")
-            sr_val = output.outputs[0].multimodal_output["sr"]
-            audio_samplerate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val[-1])
-            # Convert to numpy array and ensure correct format
-            audio_numpy = audio_tensor.float().detach().cpu().numpy()
-
-            # Ensure audio is 1D (flatten if needed)
-            if audio_numpy.ndim > 1:
-                audio_numpy = audio_numpy.flatten()
-
-            # Save audio file with explicit WAV format
-            sf.write(output_wav, audio_numpy, samplerate=audio_samplerate, format="WAV")
-            print(f"Request ID: {request_id}, Saved audio to {output_wav}")
+    batch_size = args.batch_size
+    for batch_start in range(0, len(inputs), batch_size):
+        batch = inputs[batch_start : batch_start + batch_size]
+        omni_generator = omni.generate(batch, sampling_params_list=None)
+        for stage_outputs in omni_generator:
+            for output in stage_outputs.request_output:
+                request_id = output.request_id
+                audio_data = output.outputs[0].multimodal_output["audio"]
+                # async_chunk mode returns a list of chunks; concatenate them.
+                if isinstance(audio_data, list):
+                    audio_tensor = torch.cat(audio_data, dim=-1)
+                else:
+                    audio_tensor = audio_data
+                output_wav = os.path.join(output_dir, f"output_{request_id}.wav")
+                sr_val = output.outputs[0].multimodal_output["sr"]
+                audio_samplerate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val[-1])
+                # Convert to numpy array and ensure correct format
+                audio_numpy = audio_tensor.float().detach().cpu().numpy()
+
+                # Ensure audio is 1D (flatten if needed)
+                if audio_numpy.ndim > 1:
+                    audio_numpy = audio_numpy.flatten()
+
+                # Save audio file with explicit WAV format
+                sf.write(output_wav, audio_numpy, samplerate=audio_samplerate, format="WAV")
+                print(f"Request ID: {request_id}, Saved audio to {output_wav}")
 
 
 def parse_args():
@@ -341,9 +378,9 @@ def parse_args():
         help="Threshold for using shared memory in bytes (default: 65536)",
     )
     parser.add_argument(
-        "--output-wav",
+        "--output-dir",
         default="output_audio",
-        help="[Deprecated] Output wav directory (use --output-dir).",
+        help="Output directory for generated wav files (default: output_audio).",
     )
     parser.add_argument(
         "--num-prompts",
@@ -401,6 +438,12 @@ def parse_args():
         choices=["icl", "xvec_only"],
         help="Mode tag for Base query x_vector_only_mode (default: icl).",
     )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=1,
+        help="Number of prompts per batch (default: 1, sequential).",
+    )
 
     return parser.parse_args()
```
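The template-override pattern the diff introduces in `main()` (copying the default query's `additional_information` and replacing only `"text"`) reduces to a minimal sketch. The field names in `template_info` below are illustrative stand-ins, not the real query fields.

```python
# Sketch of the prompt-construction pattern: each prompt reuses the
# template's task-specific fields and overrides only the "text" entry.
template_info = {"ref_audio": "speaker.wav", "language": "en"}  # illustrative fields
texts = ["Hello there.", "Second prompt."]

inputs = []
for text in texts:
    # dict unpacking copies the template, so the override never mutates it
    additional_information = {**template_info, "text": [text]}
    inputs.append({"additional_information": additional_information})

print(inputs[0]["additional_information"]["text"])       # ['Hello there.']
print(inputs[1]["additional_information"]["ref_audio"])  # speaker.wav
```

Because `{**template_info, ...}` builds a fresh dict per prompt, shared fields such as `ref_audio` stay identical across the batch while each request carries its own text.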