55import json
66import functools
77import logging
8+ import shutil
9+ import tempfile
10+ from typing import Optional , List
11+ from fastapi import FastAPI , UploadFile , Form
12+ from fastapi .middleware .cors import CORSMiddleware
13+ from starlette .responses import PlainTextResponse , JSONResponse
14+ import uvicorn
15+ from faster_whisper import WhisperModel
16+ import torch
17+
818from enum import Enum
919from typing import List , Optional
10-
1120import numpy as np
1221from websockets .sync .server import serve
1322from websockets .exceptions import ConnectionClosed
@@ -403,7 +412,10 @@ def run(self,
403412 single_model = False ,
404413 max_clients = 4 ,
405414 max_connection_time = 600 ,
406- cache_path = "~/.cache/whisper-live/" ):
415+ cache_path = "~/.cache/whisper-live/" ,
416+ rest_port = 8000 ,
417+ enable_rest = False ,
418+ cors_origins : Optional [str ] = None ):
407419 """
408420 Run the transcription server.
409421
@@ -427,6 +439,122 @@ def run(self,
427439 logging .info ("Single model mode currently only works with custom models." )
428440 if not BackendType .is_valid (backend ):
429441 raise ValueError (f"{ backend } is not a valid backend type. Choose backend from { BackendType .valid_types ()} " )
442+
443+ # New OpenAI-compatible REST API (toggleable via enable_rest boolean)
444+ if enable_rest :
445+ app = FastAPI (title = "WhisperLive OpenAI-Compatible API" )
446+ origins = [o .strip () for o in cors_origins .split (',' )] if cors_origins else []
447+ app .add_middleware (
448+ CORSMiddleware ,
449+ allow_origins = origins ,
450+ allow_credentials = True ,
451+ allow_methods = ["*" ], # Allows all methods (GET, POST, etc.)
452+ allow_headers = ["*" ], # Allows all headers
453+ )
454+
455+
@app.post("/v1/audio/transcriptions")
async def transcribe(
    file: UploadFile,
    model: str = Form(default="whisper-1"),
    language: Optional[str] = Form(default=None),
    prompt: Optional[str] = Form(default=None),
    response_format: str = Form(default="json"),
    temperature: float = Form(default=0.0),
    timestamp_granularities: Optional[List[str]] = Form(default=None),
    # Stubs for unsupported OpenAI params
    chunking_strategy: Optional[str] = Form(default=None),
    include: Optional[List[str]] = Form(default=None),
    known_speaker_names: Optional[List[str]] = Form(default=None),
    known_speaker_references: Optional[List[str]] = Form(default=None),
    stream: bool = Form(default=False)
):
    """OpenAI-compatible transcription endpoint.

    Saves the uploaded audio to a temp file, transcribes it with
    faster-whisper, and renders the result in one of the supported
    OpenAI-style response formats: json, text, srt, verbose_json, vtt.

    Returns a 400 JSON error for unsupported options and a 500 JSON
    error if transcription fails.
    """
    if stream:
        return JSONResponse({"error": "Streaming not supported in this backend."}, status_code=400)
    if chunking_strategy or known_speaker_names or known_speaker_references:
        logging.warning("Diarization/chunking params ignored; not supported.")

    supported_formats = ["json", "text", "srt", "verbose_json", "vtt"]
    if response_format not in supported_formats:
        return JSONResponse({"error": f"Unsupported response_format. Supported: {supported_formats}"}, status_code=400)

    if model != "whisper-1":
        logging.warning(f"Model '{model}' requested; using 'small' as fallback.")
    # Fall back to the server-configured custom model path, else "small".
    model_name = faster_whisper_custom_model_path or "small"

    tmp_path = None
    try:
        # Persist the upload to disk so faster-whisper can read it by path.
        suffix = os.path.splitext(file.filename)[1] or ".wav"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            shutil.copyfileobj(file.file, tmp)
            tmp_path = tmp.name

        device = "cuda" if torch.cuda.is_available() else "cpu"
        compute_type = "float16" if device == "cuda" else "int8"

        # NOTE(review): a fresh WhisperModel is constructed per request —
        # consider caching the instance if request throughput matters.
        transcriber = WhisperModel(model_name, device=device, compute_type=compute_type)
        want_words = bool(timestamp_granularities and "word" in timestamp_granularities)
        segments, info = transcriber.transcribe(
            tmp_path,
            language=language,
            initial_prompt=prompt,
            temperature=temperature,
            vad_filter=False,
            word_timestamps=want_words
        )
        # faster-whisper returns a lazy generator. Materialize it once:
        # the original code exhausted it while building `text`, so the
        # later per-segment loops (verbose_json/srt/vtt) saw no segments.
        segments = list(segments)
        text = " ".join([s.text.strip() for s in segments])

        if response_format == "text":
            return PlainTextResponse(text)
        elif response_format == "json":
            return {"text": text}
        elif response_format == "verbose_json":
            verbose = {
                "task": "transcribe",
                "language": info.language,
                "duration": info.duration,
                "text": text,
                "segments": []
            }
            for seg in segments:
                seg_dict = {
                    "id": seg.id,
                    "seek": seg.seek,
                    "start": seg.start,
                    "end": seg.end,
                    "text": seg.text.strip(),
                    "tokens": seg.tokens,
                    "temperature": seg.temperature,
                    "avg_logprob": seg.avg_logprob,
                    "compression_ratio": seg.compression_ratio,
                    "no_speech_prob": seg.no_speech_prob
                }
                if want_words:
                    seg_dict["words"] = [{"word": w.word, "start": w.start, "end": w.end, "probability": w.probability} for w in seg.words]
                verbose["segments"].append(seg_dict)
            return verbose
        elif response_format in ["srt", "vtt"]:
            output = []
            for i, seg in enumerate(segments, 1):
                # HH:MM:SS.mmm timestamps; SRT wants a comma decimal separator.
                start = f"{int(seg.start // 3600):02}:{int((seg.start % 3600) // 60):02}:{seg.start % 60:06.3f}"
                end = f"{int(seg.end // 3600):02}:{int((seg.end % 3600) // 60):02}:{seg.end % 60:06.3f}"
                if response_format == "srt":
                    output.append(f"{i}\n{start.replace('.', ',')} --> {end.replace('.', ',')}\n{seg.text.strip()}\n")
                else:  # vtt
                    output.append(f"{start} --> {end}\n{seg.text.strip()}\n")
            return PlainTextResponse("\n".join(output))
    except Exception as e:
        # Boundary handler: report the failure to the client as a 500.
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        # Always remove the temp file — the original only unlinked it on
        # the success path and leaked it when transcription raised.
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
548+
549+ threading .Thread (
550+ target = uvicorn .run ,
551+ args = (app ,),
552+ kwargs = {"host" : "0.0.0.0" , "port" : rest_port , "log_level" : "info" },
553+ daemon = True
554+ ).start ()
555+ logging .info (f"✅ OpenAI-Compatible API started on http://0.0.0.0:{ rest_port } " )
556+
557+ # Original WebSocket server (always supported)
430558 with serve (
431559 functools .partial (
432560 self .recv_audio ,
@@ -486,5 +614,4 @@ def cleanup(self, websocket):
486614 # Wait for translation thread to finish
487615 if hasattr (client , 'translation_thread' ) and client .translation_thread :
488616 client .translation_thread .join (timeout = 2.0 )
489- self .client_manager .remove_client (websocket )
490-
617+ self .client_manager .remove_client (websocket )