
Commit adcf682

[Audio] Improve Audio Inference Scripts (offline/online) (vllm-project#29279)
Signed-off-by: Ekagra Ranjan <[email protected]>
1 parent 21de6d4 commit adcf682

2 files changed: +113 -32 lines changed

examples/offline_inference/audio_language.py

Lines changed: 31 additions & 18 deletions
@@ -495,27 +495,40 @@ def main(args):
         temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
     )
 
-    mm_data = req_data.multi_modal_data
-    if not mm_data:
-        mm_data = {}
-        if audio_count > 0:
-            mm_data = {
-                "audio": [
-                    asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
-                ]
-            }
-
+    def get_input(start, end):
+        mm_data = req_data.multi_modal_data
+        if not mm_data:
+            mm_data = {}
+            if end - start > 0:
+                mm_data = {
+                    "audio": [
+                        asset.audio_and_sample_rate for asset in audio_assets[start:end]
+                    ]
+                }
+
+        inputs = {"multi_modal_data": mm_data}
+
+        if req_data.prompt:
+            inputs["prompt"] = req_data.prompt
+        else:
+            inputs["prompt_token_ids"] = req_data.prompt_token_ids
+
+        return inputs
+
+    # Batch inference
     assert args.num_prompts > 0
-    inputs = {"multi_modal_data": mm_data}
-
-    if req_data.prompt:
-        inputs["prompt"] = req_data.prompt
+    if audio_count != 1:
+        inputs = get_input(0, audio_count)
+        inputs = [inputs] * args.num_prompts
     else:
-        inputs["prompt_token_ids"] = req_data.prompt_token_ids
+        # For single audio input, we need to vary the audio input
+        # to avoid deduplication in vLLM engine.
+        inputs = []
+        for i in range(args.num_prompts):
+            start = i % len(audio_assets)
+            inp = get_input(start, start + 1)
+            inputs.append(inp)
 
-    if args.num_prompts > 1:
-        # Batch inference
-        inputs = [inputs] * args.num_prompts
     # Add LoRA request if applicable
     lora_request = (
         req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
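
The key behavioral change in this file is the single-audio branch: rather than sending the same clip with every prompt, it rotates through the available audio assets so the engine does not collapse identical multimodal inputs. Below is a minimal, illustrative sketch of that index rotation; the asset list is a stand-in for the example's `audio_assets`.

```python
# Illustrative sketch only: how prompt index i is mapped to a different audio
# asset, mirroring the get_input(start, start + 1) loop in the diff above.
audio_assets = ["mary_had_lamb", "winning_call"]  # stand-in asset names
num_prompts = 5

for i in range(num_prompts):
    start = i % len(audio_assets)
    end = start + 1
    print(f"prompt {i} -> audio_assets[{start}:{end}] = {audio_assets[start:end]}")
```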

examples/online_serving/openai_transcription_client.py

Lines changed: 82 additions & 14 deletions
@@ -18,21 +18,22 @@
 2. Streaming transcription using raw HTTP request to the vLLM server.
 """
 
+import argparse
 import asyncio
 
 from openai import AsyncOpenAI, OpenAI
 
 from vllm.assets.audio import AudioAsset
 
 
-def sync_openai(audio_path: str, client: OpenAI):
+def sync_openai(audio_path: str, client: OpenAI, model: str):
     """
     Perform synchronous transcription using OpenAI-compatible API.
     """
     with open(audio_path, "rb") as f:
         transcription = client.audio.transcriptions.create(
             file=f,
-            model="openai/whisper-large-v3",
+            model=model,
             language="en",
             response_format="json",
             temperature=0.0,
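
The hunk above also shows the parameters the synchronous path sends; with the hard-coded "openai/whisper-large-v3" gone, the client now passes in whichever model the server reports. A minimal standalone sketch of that flow, assuming the usual local vLLM defaults (base URL, "EMPTY" API key) and a placeholder audio file name:

```python
from openai import OpenAI

# Assumed local defaults; point these at wherever the vLLM server is running.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
model = client.models.list().data[0].id  # whichever model the server loaded

# "sample.wav" is a placeholder; any local audio file works.
with open("sample.wav", "rb") as f:
    result = client.audio.transcriptions.create(
        file=f,
        model=model,
        language="en",
        response_format="json",
        temperature=0.0,
    )
print("transcription:", result.text)
```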
@@ -42,18 +43,18 @@ def sync_openai(audio_path: str, client: OpenAI):
                 repetition_penalty=1.3,
             ),
         )
-    print("transcription result:", transcription.text)
+    print("transcription result [sync]:", transcription.text)
 
 
-async def stream_openai_response(audio_path: str, client: AsyncOpenAI):
+async def stream_openai_response(audio_path: str, client: AsyncOpenAI, model: str):
     """
     Perform asynchronous transcription using OpenAI-compatible API.
     """
-    print("\ntranscription result:", end=" ")
+    print("\ntranscription result [stream]:", end=" ")
     with open(audio_path, "rb") as f:
         transcription = await client.audio.transcriptions.create(
             file=f,
-            model="openai/whisper-large-v3",
+            model=model,
             language="en",
             response_format="json",
             temperature=0.0,
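
Both helpers take a path to an audio file. When the script is run without `--audio_path`, `main()` (further down) falls back to vLLM's bundled example clips; a short sketch of how those bundled assets resolve to local files:

```python
# Sketch: the bundled example clips are fetched and cached by vLLM on first
# use, then exposed as ordinary local file paths.
from vllm.assets.audio import AudioAsset

mary_had_lamb = str(AudioAsset("mary_had_lamb").get_local_path())
winning_call = str(AudioAsset("winning_call").get_local_path())
print(mary_had_lamb)
print(winning_call)
```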
@@ -72,7 +73,47 @@ async def stream_openai_response(audio_path: str, client: AsyncOpenAI):
     print()  # Final newline after stream ends
 
 
-def main():
+def stream_api_response(audio_path: str, model: str, openai_api_base: str):
+    """
+    Perform streaming transcription using raw HTTP requests to the vLLM API server.
+    """
+    import json
+    import os
+
+    import requests
+
+    api_url = f"{openai_api_base}/audio/transcriptions"
+    headers = {"User-Agent": "Transcription-Client"}
+    with open(audio_path, "rb") as f:
+        files = {"file": (os.path.basename(audio_path), f)}
+        data = {
+            "stream": "true",
+            "model": model,
+            "language": "en",
+            "response_format": "json",
+        }
+
+        print("\ntranscription result [stream]:", end=" ")
+        response = requests.post(
+            api_url, headers=headers, files=files, data=data, stream=True
+        )
+        for chunk in response.iter_lines(
+            chunk_size=8192, decode_unicode=False, delimiter=b"\n"
+        ):
+            if chunk:
+                data = chunk[len("data: ") :]
+                data = json.loads(data.decode("utf-8"))
+                data = data["choices"][0]
+                delta = data["delta"]["content"]
+                print(delta, end="", flush=True)
+
+                finish_reason = data.get("finish_reason")
+                if finish_reason is not None:
+                    print(f"\n[Stream finished reason: {finish_reason}]")
+                    break
+
+
+def main(args):
     mary_had_lamb = str(AudioAsset("mary_had_lamb").get_local_path())
     winning_call = str(AudioAsset("winning_call").get_local_path())
 
@@ -84,14 +125,41 @@ def main():
         base_url=openai_api_base,
     )
 
-    sync_openai(mary_had_lamb, client)
+    model = client.models.list().data[0].id
+    print(f"Using model: {model}")
+
+    # Run the synchronous function
+    sync_openai(args.audio_path if args.audio_path else mary_had_lamb, client, model)
+
     # Run the asynchronous function
-    client = AsyncOpenAI(
-        api_key=openai_api_key,
-        base_url=openai_api_base,
-    )
-    asyncio.run(stream_openai_response(winning_call, client))
+    if "openai" in model:
+        client = AsyncOpenAI(
+            api_key=openai_api_key,
+            base_url=openai_api_base,
+        )
+        asyncio.run(
+            stream_openai_response(
+                args.audio_path if args.audio_path else winning_call, client, model
+            )
+        )
+    else:
+        stream_api_response(
+            args.audio_path if args.audio_path else winning_call,
+            model,
+            openai_api_base,
+        )
 
 
 if __name__ == "__main__":
-    main()
+    # setup argparser
+    parser = argparse.ArgumentParser(
+        description="OpenAI Transcription Client using vLLM API Server"
+    )
+    parser.add_argument(
+        "--audio_path",
+        type=str,
+        default=None,
+        help="The path to the audio file to transcribe.",
+    )
+    args = parser.parse_args()
+    main(args)
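
The new `stream_api_response` helper reads the HTTP response line by line as server-sent events: each non-empty line is a `data: ` prefix followed by a JSON chunk. A small sketch, with a made-up payload, of the parsing that loop performs on one such line:

```python
import json

# Hypothetical example of a single streamed line; the field names mirror what
# stream_api_response reads ("choices" -> "delta" -> "content", "finish_reason").
line = b'data: {"choices": [{"delta": {"content": "Hello"}, "finish_reason": null}]}'

payload = json.loads(line[len("data: "):].decode("utf-8"))
choice = payload["choices"][0]
print(choice["delta"]["content"])   # -> Hello
print(choice.get("finish_reason"))  # -> None until the stream finishes
```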
