Skip to content

Commit 29ee640

Browse files
aaron-boxer and boxerab
authored and committed
api: add support for OpenAI REST transcription API
1 parent b9ae2af commit 29ee640

File tree

5 files changed

+218
-7
lines changed

5 files changed

+218
-7
lines changed

README.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,22 @@ source whisper_env/bin/activate
4545
```
4646

4747

48+
### OpenAI REST interface
49+
50+
#### Server
51+
52+
```bash
53+
python3 run_server.py --port 9090 --backend faster_whisper --max_clients 4 --max_connection_time 600 --enable_rest --cors-origins="http://localhost:8080,http://127.0.0.1:8080"
54+
```
55+
56+
#### Client
57+
58+
```bash
59+
python3 client_openai.py $AUDIO_FILE
60+
```
61+
62+
63+
4864
### Setting up NVIDIA/TensorRT-LLM for TensorRT backend
4965
- Please follow [TensorRT_whisper readme](https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md) for setup of [NVIDIA/TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) and for building Whisper-TensorRT engine.
5066

client_openai.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
"""Minimal client for WhisperLive's OpenAI-compatible REST endpoint.

Uploads an audio file to /v1/audio/transcriptions and prints the
transcript (or the server's error).
"""
import sys

import requests

if len(sys.argv) < 2:
    # Fixed: usage line previously named the wrong script (transcribe_file.py).
    print("Usage: python client_openai.py <path_to_audio_file>")
    sys.exit(1)

audio_file = sys.argv[1]

# Configuration
host = "localhost"
port = 8000  # Default REST port; change if you used --rest_port
url = f"http://{host}:{port}/v1/audio/transcriptions"
model = "small"  # Or "whisper-1" (mapped to small internally)
language = "en"  # Or "hi" for Hindi
response_format = "json"  # Options: "json", "text", "verbose_json", "srt", "vtt"

# Prepare the request
data = {
    "model": model,
    "language": language,
    "response_format": response_format,
    # Optional: Add "prompt" for style guidance, "temperature" (0-1), etc.
}

# Send the request. Open the upload in a context manager so the file
# handle is closed even if the request raises (the original leaked it).
with open(audio_file, "rb") as audio_fh:
    response = requests.post(url, files={"file": audio_fh}, data=data)

if response.status_code == 200:
    if response_format in ("json", "verbose_json"):
        result = response.json()
        print("Transcript:", result.get("text", "No text found"))
        # If you need translation, post-process here (e.g., using another API like Google Translate)
    else:
        print("Transcript:", response.text)
else:
    # The server may return a non-JSON body on failure (e.g. a proxy error
    # page); fall back to the raw text instead of crashing on .json().
    try:
        detail = response.json().get("error", "Unknown error")
    except ValueError:
        detail = response.text
    print("Error:", response.status_code, detail)

requirements/server.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,8 @@ openvino
2020
openvino-genai
2121
openvino-tokenizers
2222
optimum
23-
optimum-intel
23+
optimum-intel
24+
25+
fastapi
26+
uvicorn
27+
python-multipart

run_server.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
import argparse
22
import os
3+
import threading
4+
import logging
5+
from fastapi import FastAPI
6+
from fastapi import UploadFile, Form
7+
import uvicorn
8+
import tempfile
9+
import shutil
10+
import json
11+
from starlette.responses import PlainTextResponse, JSONResponse
312

413
if __name__ == "__main__":
514
parser = argparse.ArgumentParser()
@@ -43,6 +52,20 @@
4352
type=str,
4453
default="~/.cache/whisper-live/",
4554
help='Path to cache the converted ctranslate2 models.')
55+
# REST-API related command-line options (option strings, defaults and help
# texts are identical to the originals; only formatting differs).
parser.add_argument(
    "--rest_port",
    type=int,
    default=8000,
    help="Port for the REST API server.",
)
parser.add_argument(
    "--enable_rest",
    action="store_true",
    help="Enable the OpenAI-compatible REST API endpoint.",
)
parser.add_argument(
    "--cors-origins",
    type=str,
    default=None,
    help="Comma-separated list of allowed CORS origins (e.g., 'http://localhost:3000,http://example.com'). Defaults to localhost/127.0.0.1 on the WebSocket port.",
)
4669
args = parser.parse_args()
4770

4871
if args.backend == "tensorrt":
@@ -65,5 +88,8 @@
6588
single_model=not args.no_single_model,
6689
max_clients=args.max_clients,
6790
max_connection_time=args.max_connection_time,
68-
cache_path=args.cache_path
69-
)
91+
cache_path=args.cache_path,
92+
rest_port=args.rest_port,
93+
enable_rest=args.enable_rest,
94+
cors_origins=args.cors_origins,
95+
)

whisper_live/server.py

Lines changed: 131 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,18 @@
55
import json
66
import functools
77
import logging
8+
import shutil
9+
import tempfile
10+
from typing import Optional, List
11+
from fastapi import FastAPI, UploadFile, Form
12+
from fastapi.middleware.cors import CORSMiddleware
13+
from starlette.responses import PlainTextResponse, JSONResponse
14+
import uvicorn
15+
from faster_whisper import WhisperModel
16+
import torch
17+
818
from enum import Enum
919
from typing import List, Optional
10-
1120
import numpy as np
1221
from websockets.sync.server import serve
1322
from websockets.exceptions import ConnectionClosed
@@ -403,7 +412,10 @@ def run(self,
403412
single_model=False,
404413
max_clients=4,
405414
max_connection_time=600,
406-
cache_path="~/.cache/whisper-live/"):
415+
cache_path="~/.cache/whisper-live/",
416+
rest_port=8000,
417+
enable_rest=False,
418+
cors_origins: Optional[str] = None):
407419
"""
408420
Run the transcription server.
409421
@@ -427,6 +439,122 @@ def run(self,
427439
logging.info("Single model mode currently only works with custom models.")
428440
if not BackendType.is_valid(backend):
429441
raise ValueError(f"{backend} is not a valid backend type. Choose backend from {BackendType.valid_types()}")
442+
443+
# New OpenAI-compatible REST API (toggleable via enable_rest boolean)
if enable_rest:
    app = FastAPI(title="WhisperLive OpenAI-Compatible API")
    # Only origins the operator explicitly listed are allowed; an empty list
    # means cross-origin browser requests are rejected.
    origins = [o.strip() for o in cors_origins.split(',')] if cors_origins else []
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],  # Allows all methods (GET, POST, etc.)
        allow_headers=["*"],  # Allows all headers
    )

    def _format_timestamp(seconds):
        """Render a float second offset as HH:MM:SS.mmm."""
        return f"{int(seconds // 3600):02}:{int((seconds % 3600) // 60):02}:{seconds % 60:06.3f}"

    @app.post("/v1/audio/transcriptions")
    async def transcribe(
        file: UploadFile,
        model: str = Form(default="whisper-1"),
        language: Optional[str] = Form(default=None),
        prompt: Optional[str] = Form(default=None),
        response_format: str = Form(default="json"),
        temperature: float = Form(default=0.0),
        timestamp_granularities: Optional[List[str]] = Form(default=None),
        # Stubs for unsupported OpenAI params
        chunking_strategy: Optional[str] = Form(default=None),
        include: Optional[List[str]] = Form(default=None),
        known_speaker_names: Optional[List[str]] = Form(default=None),
        known_speaker_references: Optional[List[str]] = Form(default=None),
        stream: bool = Form(default=False)
    ):
        """OpenAI-compatible /v1/audio/transcriptions endpoint.

        Saves the uploaded audio to a temp file, transcribes it with
        faster-whisper, and returns the result in the requested format
        (json, text, verbose_json, srt, or vtt).
        """
        if stream:
            return JSONResponse({"error": "Streaming not supported in this backend."}, status_code=400)
        if chunking_strategy or known_speaker_names or known_speaker_references:
            logging.warning("Diarization/chunking params ignored; not supported.")

        supported_formats = ["json", "text", "srt", "verbose_json", "vtt"]
        if response_format not in supported_formats:
            return JSONResponse({"error": f"Unsupported response_format. Supported: {supported_formats}"}, status_code=400)

        if model != "whisper-1":
            logging.warning(f"Model '{model}' requested; using 'small' as fallback.")
        model_name = faster_whisper_custom_model_path or "small"

        tmp_path = None
        try:
            suffix = os.path.splitext(file.filename)[1] or ".wav"
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                shutil.copyfileobj(file.file, tmp)
                tmp_path = tmp.name

            device = "cuda" if torch.cuda.is_available() else "cpu"
            compute_type = "float16" if device == "cuda" else "int8"

            # NOTE(review): the model is loaded on every request; cache the
            # WhisperModel instance if request throughput matters.
            transcriber = WhisperModel(model_name, device=device, compute_type=compute_type)
            want_words = bool(timestamp_granularities and "word" in timestamp_granularities)
            segments, info = transcriber.transcribe(
                tmp_path,
                language=language,
                initial_prompt=prompt,
                temperature=temperature,
                vad_filter=False,
                word_timestamps=want_words
            )
            # Fix: faster-whisper returns a lazy generator. The original code
            # exhausted it in the " ".join(...) below and then iterated
            # `segments` again for verbose_json/srt/vtt, silently getting
            # zero segments. Materialize it once so both uses work.
            segments = list(segments)
            text = " ".join(s.text.strip() for s in segments)

            if response_format == "text":
                return PlainTextResponse(text)
            elif response_format == "json":
                return {"text": text}
            elif response_format == "verbose_json":
                verbose = {
                    "task": "transcribe",
                    "language": info.language,
                    "duration": info.duration,
                    "text": text,
                    "segments": []
                }
                for seg in segments:
                    seg_dict = {
                        "id": seg.id,
                        "seek": seg.seek,
                        "start": seg.start,
                        "end": seg.end,
                        "text": seg.text.strip(),
                        "tokens": seg.tokens,
                        "temperature": seg.temperature,
                        "avg_logprob": seg.avg_logprob,
                        "compression_ratio": seg.compression_ratio,
                        "no_speech_prob": seg.no_speech_prob
                    }
                    if want_words:
                        seg_dict["words"] = [
                            {"word": w.word, "start": w.start, "end": w.end, "probability": w.probability}
                            for w in seg.words
                        ]
                    verbose["segments"].append(seg_dict)
                return verbose
            elif response_format in ["srt", "vtt"]:
                output = []
                for i, seg in enumerate(segments, 1):
                    start = _format_timestamp(seg.start)
                    end = _format_timestamp(seg.end)
                    if response_format == "srt":
                        # SRT timestamps use a comma decimal separator.
                        output.append(f"{i}\n{start.replace('.', ',')} --> {end.replace('.', ',')}\n{seg.text.strip()}\n")
                    else:  # vtt
                        # NOTE(review): a strict WebVTT file begins with a
                        # "WEBVTT" header line; output kept as the original
                        # emitted it — confirm whether consumers need it.
                        output.append(f"{start} --> {end}\n{seg.text.strip()}\n")
                return PlainTextResponse("\n".join(output))
        except Exception as e:
            return JSONResponse({"error": str(e)}, status_code=500)
        finally:
            # Fix: remove the temp file on success AND failure — the original
            # unlinked it only on the success path, leaking a file per failed
            # request.
            if tmp_path:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

    # Run the REST server in a daemon thread so the blocking WebSocket
    # server below keeps the main thread.
    threading.Thread(
        target=uvicorn.run,
        args=(app,),
        kwargs={"host": "0.0.0.0", "port": rest_port, "log_level": "info"},
        daemon=True
    ).start()
    logging.info(f"✅ OpenAI-Compatible API started on http://0.0.0.0:{rest_port}")
556+
557+
# Original WebSocket server (always supported)
430558
with serve(
431559
functools.partial(
432560
self.recv_audio,
@@ -486,5 +614,4 @@ def cleanup(self, websocket):
486614
# Wait for translation thread to finish
487615
if hasattr(client, 'translation_thread') and client.translation_thread:
488616
client.translation_thread.join(timeout=2.0)
489-
self.client_manager.remove_client(websocket)
490-
617+
self.client_manager.remove_client(websocket)

0 commit comments

Comments
 (0)