
Commit 3a46b24

davidgao7 authored and YuhanLiu11 committed
[feat]: add transcription API endpoint using OpenAI Whisper-small (vllm-project#469)
* [feat]: add transcription API endpoint using OpenAI Whisper-small
* remove the whisper payload response log
* [docs]: add tutorial for transcription v1 api
* [chore] align example router running script with main; the new script will be mentioned in `tutorials/17-whisper-api-transcription.md`
* omit the model field since the backend already knows which model to run
* generate a silent audio file if no audio file appears
* put wav creation at the module level to prevent it being recreated every time
* [Test] test frequency of silent audio creation
* send multipart/form-data for the transcription model's health check
* fix pre-commit issue
* move the implementation of the `/v1/audio/transcriptions` endpoint from `main_router.py` into `request.py` to align with the architectural pattern
* add a timeout to ensure the health check will not hang indefinitely if a backend model becomes unresponsive
* add a boolean health check return for non-transcription models
* remove a redundant warning log, since it is handled in the outer `StaticServiceDiscovery.get_unhealthy_endpoint_hashes`
* remove a redundant JSONDecodeError catch and downgrade the RequestException log to debug, aligning with service discovery's warning
* Chore: apply auto-formatting and linting fixes via pre-commit
* refactor: write more meaningful comments for silent wav bytes generation
* refactor: keep the comment explaining the purpose of generating silent WAV bytes
* fix(tests): improve the mock in the model health check test. The mock for `requests.post` in `test_is_model_healthy` did not correctly simulate an `HTTPError` on a non-200 response; this change configures the mock's `raise_for_status` method to raise the appropriate exception, so the test now accurately validates the function's error-handling logic
* Chore: apply auto-formatting and linting fixes via pre-commit
* chore: remove the unused var `in_router_time`
* fix(deps): add httpx as an explicit dependency. The CI/CD workflow was failing with a `ModuleNotFoundError` because `httpx` was not an explicit dependency in `pyproject.toml` and was not being installed in the clean Docker environment
* chore: dependency order changes after running pre-commit
* refactor: migrate from httpx to aiohttp for improved concurrency. Replaced httpx with aiohttp for better asynchronous performance and resource utilization; fixed a JSON syntax error in error response handling
* chore: remove the wrong tutorial file
* chore: apply pre-commit
* chore: use debug log print
* chore: change to more specific exception handling for aiohttp

Signed-off-by: David Gao <[email protected]>
Co-authored-by: Yuhan Liu <[email protected]>
Signed-off-by: Ifta Khairul Alam Adil <[email protected]>
1 parent 646cf5c commit 3a46b24

File tree

5 files changed: +366, -19 lines changed


src/tests/test_utils.py

Lines changed: 9 additions & 1 deletion

```diff
@@ -104,6 +104,14 @@ def test_is_model_healthy_when_requests_raises_exception_returns_false(
 def test_is_model_healthy_when_requests_status_with_status_code_not_200_returns_false(
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
-    request_mock = MagicMock(return_value=MagicMock(status_code=500))
+
+    # Mock an internal server error response
+    mock_response = MagicMock(status_code=500)
+
+    # Tell the mock to raise an HTTP Error when raise_for_status() is called
+    mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError
+
+    request_mock = MagicMock(return_value=mock_response)
     monkeypatch.setattr("requests.post", request_mock)
+
     assert utils.is_model_healthy("http://localhost", "test", "chat") is False
```
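This fix matters because `is_model_healthy` now relies on `raise_for_status()` rather than comparing `status_code` to 200, so a bare `MagicMock(status_code=500)` would never enter the error path. A self-contained sketch of the pattern, with an illustrative stand-in for the real health check (not repo code):

```python
from unittest.mock import MagicMock

import requests


def check(url: str) -> bool:
    # Stand-in health check: depends on raise_for_status(), not status_code
    try:
        requests.post(url, timeout=10).raise_for_status()
        return True
    except requests.exceptions.RequestException:
        return False


def test_check_returns_false_on_500(monkeypatch):
    mock_response = MagicMock(status_code=500)
    # Without this side_effect, raise_for_status() is a silent no-op mock
    # and check() would wrongly report the model as healthy.
    mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError
    monkeypatch.setattr("requests.post", MagicMock(return_value=mock_response))
    assert check("http://localhost") is False
```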

src/vllm_router/routers/main_router.py

Lines changed: 17 additions & 2 deletions

```diff
@@ -13,7 +13,11 @@
 # limitations under the License.
 import json
 
-from fastapi import APIRouter, BackgroundTasks, Request
+from fastapi import (
+    APIRouter,
+    BackgroundTasks,
+    Request,
+)
 from fastapi.responses import JSONResponse, Response
 
 from vllm_router.dynamic_config import get_dynamic_config_watcher
@@ -22,6 +26,7 @@
 from vllm_router.service_discovery import get_service_discovery
 from vllm_router.services.request_service.request import (
     route_general_request,
+    route_general_transcriptions,
     route_sleep_wakeup_request,
 )
 from vllm_router.stats.engine_stats import get_engine_stats_scraper
@@ -123,7 +128,7 @@ async def show_version():
 @main_router.get("/v1/models")
 async def show_models():
     """
-    Returns a list of all models available in the stack
+    Returns a list of all models available in the stack.
 
     Args:
         None
@@ -229,3 +234,13 @@ async def health() -> Response:
         )
     else:
         return JSONResponse(content={"status": "healthy"}, status_code=200)
+
+
+@main_router.post("/v1/audio/transcriptions")
+async def route_v1_audio_transcriptions(
+    request: Request, background_tasks: BackgroundTasks
+):
+    """Handles audio transcription requests."""
+    return await route_general_transcriptions(
+        request, "/v1/audio/transcriptions", background_tasks
+    )
```
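With this route registered, clients can hit the endpoint with OpenAI-style multipart form data. A minimal client-side sketch, assuming the router listens on `http://localhost:30080`, a backend labeled `transcription` serves `openai/whisper-small`, and `sample.wav` exists (all three are illustrative assumptions, not values from this diff):

```python
import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:30080/v1/audio/transcriptions",
        files={"file": ("sample.wav", f, "audio/wav")},
        data={"model": "openai/whisper-small", "language": "en"},
        timeout=300,
    )
resp.raise_for_status()
# with the default response_format="json", the body carries a "text" field
print(resp.json()["text"])
```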

src/vllm_router/services/request_service/request.py

Lines changed: 183 additions & 7 deletions

```diff
@@ -17,9 +17,10 @@
 import os
 import time
 import uuid
+from typing import Optional
 
 import aiohttp
-from fastapi import BackgroundTasks, HTTPException, Request
+from fastapi import BackgroundTasks, HTTPException, Request, UploadFile
 from fastapi.responses import JSONResponse, StreamingResponse
 from requests import JSONDecodeError
 
@@ -304,9 +305,7 @@ async def route_general_request(
 async def send_request_to_prefiller(
     client: aiohttp.ClientSession, endpoint: str, req_data: dict, request_id: str
 ):
-    """
-    Send a request to a prefiller service.
-    """
+    """Send a request to a prefiller service."""
     req_data = req_data.copy()
     req_data["max_tokens"] = 1
     if "max_completion_tokens" in req_data:
@@ -325,9 +324,7 @@ async def send_request_to_prefiller(
 async def send_request_to_decode(
     client: aiohttp.ClientSession, endpoint: str, req_data: dict, request_id: str
 ):
-    """
-    Asynchronously stream the response from a service using a persistent client.
-    """
+    """Asynchronously stream the response from a service using a persistent client."""
     headers = {
         "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
         "X-Request-Id": request_id,
@@ -511,3 +508,182 @@ async def route_sleep_wakeup_request(
         content={"status": "success"},
         headers={"X-Request-Id": request_id},
     )
+
+
+async def route_general_transcriptions(
+    request: Request,
+    endpoint: str,  # "/v1/audio/transcriptions"
+    background_tasks: BackgroundTasks,
+):
+    """Handles audio transcription requests by parsing form data and proxying to backend."""
+
+    request_id = request.headers.get("X-Request-Id", str(uuid.uuid4()))
+
+    # --- 1. Form parsing ---
+    try:
+        form = await request.form()
+
+        # Extract parameters from the form data
+        file: UploadFile = form["file"]
+        model: str = form["model"]
+        prompt: Optional[str] = form.get("prompt", None)
+        response_format: Optional[str] = form.get("response_format", "json")
+        temperature_str: Optional[str] = form.get("temperature", None)
+        temperature: Optional[float] = (
+            float(temperature_str) if temperature_str is not None else None
+        )
+        language: Optional[str] = form.get("language", "en")
+    except KeyError as e:
+        return JSONResponse(
+            status_code=400,
+            content={"error": f"Invalid request: missing '{e.args[0]}' in form data."},
+        )
+
+    logger.debug("==== Enter audio_transcriptions ====")
+    logger.debug("Received upload: %s (%s)", file.filename, file.content_type)
+    logger.debug(
+        "Params: model=%s prompt=%r response_format=%r temperature=%r language=%s",
+        model,
+        prompt,
+        response_format,
+        temperature,
+        language,
+    )
+
+    # --- 2. Service Discovery and Routing ---
+    # Access singletons via request.app.state for consistent style
+    service_discovery = (
+        get_service_discovery()
+    )  # This one is often still accessed directly via its get function
+    router = request.app.state.router  # Access router from app.state
+    engine_stats_scraper = (
+        request.app.state.engine_stats_scraper
+    )  # Access engine_stats_scraper from app.state
+    request_stats_monitor = (
+        request.app.state.request_stats_monitor
+    )  # Access request_stats_monitor from app.state
+
+    endpoints = service_discovery.get_endpoint_info()
+
+    logger.debug("==== Total endpoints ====")
+    logger.debug(endpoints)
+    logger.debug("==== Total endpoints ====")
+
+    # filter the endpoint urls by model name and label for transcriptions
+    transcription_endpoints = [
+        ep
+        for ep in endpoints
+        if model == ep.model_name
+        and ep.model_label == "transcription"
+        and not ep.sleep  # skip endpoints that are asleep
+    ]
+
+    logger.debug("====List of transcription endpoints====")
+    logger.debug(transcription_endpoints)
+    logger.debug("====List of transcription endpoints====")
+
+    if not transcription_endpoints:
+        logger.error("No transcription backend available for model %s", model)
+        return JSONResponse(
+            status_code=404,
+            content={"error": f"No transcription backend for model {model}"},
+        )
+
+    # grab the current engine and request stats
+    engine_stats = engine_stats_scraper.get_engine_stats()
+    request_stats = request_stats_monitor.get_request_stats(time.time())
+
+    # pick one using the router's configured logic (roundrobin, least-loaded, etc.)
+    chosen_url = router.route_request(
+        transcription_endpoints,
+        engine_stats,
+        request_stats,
+        request,
+    )
+
+    logger.debug("Proxying transcription request to %s", chosen_url)
+
+    # --- 3. Prepare and Proxy the Request ---
+    payload_bytes = await file.read()
+    files = {"file": (file.filename, payload_bytes, file.content_type)}
+
+    data = {"model": model, "language": language}
+
+    if prompt:
+        data["prompt"] = prompt
+
+    if response_format:
+        data["response_format"] = response_format
+
+    if temperature is not None:
+        data["temperature"] = str(temperature)
+
+    logger.info("Proxying transcription request for model %s to %s", model, chosen_url)
+
+    logger.debug("==== data payload keys ====")
+    logger.debug(list(data.keys()))
+    logger.debug("==== data payload keys ====")
+
+    try:
+        client = request.app.state.aiohttp_client_wrapper()
+
+        form_data = aiohttp.FormData()
+
+        # add file data
+        for key, (filename, content, content_type) in files.items():
+            form_data.add_field(
+                key, content, filename=filename, content_type=content_type
+            )
+
+        # add form data
+        for key, value in data.items():
+            form_data.add_field(key, value)
+
+        backend_response = await client.post(
+            f"{chosen_url}{endpoint}",
+            data=form_data,
+            timeout=aiohttp.ClientTimeout(total=300),
+        )
+
+        # --- 4. Return the response ---
+        response_content = await backend_response.json()
+        headers = {
+            k: v
+            for k, v in backend_response.headers.items()
+            if k.lower() not in ("content-encoding", "transfer-encoding", "connection")
+        }
+
+        headers["X-Request-Id"] = request_id
+
+        return JSONResponse(
+            content=response_content,
+            status_code=backend_response.status,
+            headers=headers,
+        )
+    except aiohttp.ClientResponseError as response_error:
+        if response_error.response is not None:
+            try:
+                error_content = await response_error.response.json()
+            except (
+                aiohttp.ContentTypeError,
+                json.JSONDecodeError,
+                aiohttp.ClientError,
+            ):
+                # If JSON parsing fails, get text content
+                try:
+                    text_content = await response_error.response.text()
+                    error_content = {"error": text_content}
+                except aiohttp.ClientError:
+                    error_content = {
+                        "error": f"HTTP {response_error.status}: {response_error.message}"
+                    }
+        else:
+            error_content = {
+                "error": f"HTTP {response_error.status}: {response_error.message}"
+            }
+        return JSONResponse(status_code=response_error.status, content=error_content)
+    except aiohttp.ClientError as client_error:
+        return JSONResponse(
+            status_code=503,
+            content={"error": f"Failed to connect to backend: {str(client_error)}"},
+        )
```

src/vllm_router/utils.py

Lines changed: 58 additions & 9 deletions

```diff
@@ -1,8 +1,10 @@
 import abc
 import enum
+import io
 import json
 import re
 import resource
+import wave
 from typing import Optional
 
 import requests
@@ -13,6 +15,23 @@
 
 logger = init_logger(__name__)
 
+# prepare a WAV byte to prevent repeatedly generating it
+# Generate a 0.1 second silent audio file
+# This will be used for the /v1/audio/transcriptions endpoint
+_SILENT_WAV_BYTES = None
+with io.BytesIO() as wav_buffer:
+    with wave.open(wav_buffer, "wb") as wf:
+        wf.setnchannels(1)  # mono audio channel, standard configuration
+        wf.setsampwidth(2)  # 16 bit audio, common bit depth for wav file
+        wf.setframerate(16000)  # 16 kHz sample rate
+        wf.writeframes(b"\x00\x00" * 1600)  # 0.1 second of silence
+
+    # retrieves the generated wav bytes
+    _SILENT_WAV_BYTES = wav_buffer.getvalue()
+    logger.debug(
+        "======A default silent WAV file has been stored in memory within py application process===="
+    )
+
 
 class SingletonMeta(type):
     _instances = {}
@@ -52,6 +71,7 @@ class ModelType(enum.Enum):
     embeddings = "/v1/embeddings"
     rerank = "/v1/rerank"
     score = "/v1/score"
+    transcription = "/v1/audio/transcriptions"
 
     @staticmethod
     def get_test_payload(model_type: str):
@@ -75,6 +95,12 @@ def get_test_payload(model_type: str):
                 return {"query": "Hello", "documents": ["Test"]}
             case ModelType.score:
                 return {"encoding_format": "float", "text_1": "Test", "test_2": "Test2"}
+            case ModelType.transcription:
+                if _SILENT_WAV_BYTES is not None:
+                    logger.debug("=====Silent WAV Bytes is being used=====")
+                return {
+                    "file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav"),
+                }
 
     @staticmethod
     def get_all_fields():
@@ -161,14 +187,37 @@ def update_content_length(request: Request, request_body: str):
 
 def is_model_healthy(url: str, model: str, model_type: str) -> bool:
     model_details = ModelType[model_type]
+
     try:
-        response = requests.post(
-            f"{url}{model_details.value}",
-            headers={"Content-Type": "application/json"},
-            json={"model": model} | model_details.get_test_payload(model_type),
-            timeout=30,
-        )
-    except Exception as e:
-        logger.error(e)
+        if model_type == "transcription":
+
+            # for transcription, the backend expects multipart/form-data with a file
+            # we will use pre-generated silent wav bytes
+            files = {"file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav")}
+            data = {"model": model}
+            response = requests.post(
+                f"{url}{model_details.value}",
+                files=files,  # multipart/form-data
+                data=data,
+                timeout=10,
+            )
+        else:
+            # for other model types (chat, completion, etc.)
+            response = requests.post(
+                f"{url}{model_details.value}",
+                headers={"Content-Type": "application/json"},
+                json={"model": model} | model_details.get_test_payload(model_type),
+                timeout=10,
+            )
+
+        response.raise_for_status()
+
+        if model_type == "transcription":
+            return True
+        else:
+            response.json()  # verify it's valid json for other model types
+            return True  # validation passed
+
+    except requests.exceptions.RequestException as e:
+        logger.debug(f"{model_type} Model {model} at {url} is not healthy: {e}")
         return False
-    return response.status_code == 200
```
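As a sanity check on the module-level constant, the same `wave` calls can be replayed standalone: 1600 two-byte frames at 16 kHz is exactly 0.1 s of mono silence (a sketch, not repo code):

```python
import io
import wave

# Reproduce the module-level silent-WAV generation
with io.BytesIO() as buf:
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(1)  # mono
        wf.setsampwidth(2)  # 16-bit samples
        wf.setframerate(16000)  # 16 kHz
        wf.writeframes(b"\x00\x00" * 1600)
    silent_wav = buf.getvalue()

# Read it back and confirm the duration
with wave.open(io.BytesIO(silent_wav), "rb") as wf:
    duration = wf.getnframes() / wf.getframerate()
assert duration == 0.1
print(f"{len(silent_wav)} bytes, {duration:.1f}s of silence")
```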

0 commit comments