diff --git a/examples/inference-server/Dockerfile b/examples/inference-server/Dockerfile new file mode 100644 index 0000000..f9cee84 --- /dev/null +++ b/examples/inference-server/Dockerfile @@ -0,0 +1,65 @@ +# Dockerfile of qwenllm/qwen3-asr:cu128 + +ARG CUDA_VERSION=12.8.0 +ARG from=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 +FROM ${from} AS base + +ARG DEBIAN_FRONTEND=noninteractive +RUN < 1 else self.latencies[0]:.3f}") + + if self.rtfs: + print("\n-- RTF --") + print(f"Avg: {statistics.mean(self.rtfs):.4f}") + print(f"P50: {statistics.median(self.rtfs):.4f}") + print(f"P95: {statistics.quantiles(self.rtfs, n=20)[-1] if len(self.rtfs) > 1 else self.rtfs[0]:.4f}") + +async def run_command(cmd): + t0 = time.time() + proc = await create_subprocess_exec( + *cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + stdout, stderr = await proc.communicate() + t1 = time.time() + return t1 - t0, stdout.decode(), stderr.decode(), proc.returncode + +async def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--mode", choices=["streaming", "batch"], required=True) + parser.add_argument("--url", required=True) + parser.add_argument("--file", required=True) + parser.add_argument("--clients", type=int, default=1) + parser.add_argument("--requests", type=int, default=1) + args = parser.parse_args() + + print(f"Starting Benchmark: {args.clients} clients, {args.requests} requests each.") + + stats = BenchmarkStats() + stats.start_time = time.time() + + # Semaphore to limit concurrency if needed, but here clients=concurrency + sem = asyncio.Semaphore(args.clients) + + async def worker(): + async with sem: + for _ in range(args.requests): + if args.mode == "streaming": + cmd = ["python", "client-streaming.py", "-e", args.url, "-f", args.file] + dur, out, err, rc = await run_command(cmd) + if rc == 0: + stats.add_streaming_result(out) + else: + print(f"Error: {err}") + stats.errors += 1 + else: + # Batch uses curl + # curl -s -X POST url -F files=@file + cmd = ["curl", "-s", "-X", "POST", args.url, "-F", f"files=@{args.file}"] + dur, out, err, rc = await run_command(cmd) + if rc == 0 and "Error" not in out: # Simple check + stats.add_batch_result(dur, out, rc) + else: + stats.errors += 1 + + tasks = [worker() for _ in range(args.clients)] + await asyncio.gather(*tasks) + + stats.end_time = time.time() + stats.report() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/inference-server/client-streaming.js b/examples/inference-server/client-streaming.js new file mode 100644 index 0000000..48d69d1 --- /dev/null +++ b/examples/inference-server/client-streaming.js @@ -0,0 +1,95 @@ +/* Client for streaming ASR - this is for TESTING PURPOSES */ +import WebSocket from "ws"; +import fs from "fs"; + +const args = process.argv.slice(2); +let endpoint = ""; +let filePath = ""; + +for (let i = 0; i < args.length; i++) { + if (args[i] === "-e" && args[i + 1]) { + endpoint = args[i + 1]; + i++; + } else if (args[i] === "-f" && args[i + 1]) { + filePath = args[i + 1]; + i++; + } +} + +if (!endpoint || !filePath) { + console.error("Usage: node client-streaming.js -e -f "); + process.exit(1); +} + +const CHUNK_BYTES = 640; +const ws = new WebSocket(endpoint); + +console.log(`Connecting to ${endpoint}...`); + +const stats = fs.statSync(filePath); +const duration = stats.size / 32000.0; +console.log(`Audio Duration: ${duration.toFixed(2)}s`); + +const startTime = Date.now(); + +ws.on("open", () => { + console.log("Connected."); + ws.send(JSON.stringify({ + type: "start", + format: "pcm_s16le", + sample_rate_hz: 16000, + channels: 1 + })); + + try { + const fd = fs.openSync(filePath, "r"); + const buf = Buffer.alloc(CHUNK_BYTES); + + const sendChunk = () => { + const n = fs.readSync(fd, buf, 0, CHUNK_BYTES, null); + if (n > 0) { + // Must copy buffer because ws.send is async and we reuse buf immediately + ws.send(Buffer.from(buf.subarray(0, n))); + setImmediate(sendChunk); + } else { + fs.closeSync(fd); + ws.send(JSON.stringify({ type: "stop" })); + console.log("Finished sending audio."); + } + }; + + sendChunk(); + } catch (err) { + console.error(`Error reading file: ${err.message}`); + ws.close(); + } +}); + +ws.on("message", (data) => { + try { + const timestamp = new Date().toISOString().split('T')[1].slice(0, -1); + const evt = JSON.parse(data.toString()); + if (evt.type === 'partial') { + process.stdout.write(`\r[${timestamp}] [Partial] ${evt.text}`); + } else if (evt.type === 'final') { + console.log(`\n[${timestamp}] [Final] ${evt.text}`); + } else { + console.log(`\n[${timestamp}] [${evt.type}] ${JSON.stringify(evt)}`); + } + } catch { + console.log("\n[Non-JSON]", data.toString()); + } +}); + +ws.on("close", () => { + const endTime = Date.now(); + const processTime = (endTime - startTime) / 1000.0; + const rtf = processTime / duration; + console.log(`\nProcessing Time: ${processTime.toFixed(2)}s`); + console.log(`Real-Time Factor (RTF): ${rtf.toFixed(4)}`); + console.log("\nDisconnected."); +}); + +ws.on("error", (err) => { + console.error(`WebSocket error: ${err.message}`); +}); diff --git a/examples/inference-server/client-streaming.py b/examples/inference-server/client-streaming.py new file mode 100644 index 0000000..bbe6e6b --- /dev/null +++ b/examples/inference-server/client-streaming.py @@ -0,0 +1,88 @@ +# Client for streaming ASR - this is for TESTING PURPOSES + +import asyncio +import json +import argparse +import sys +import websockets +from datetime import datetime + +import os +import time + +CHUNK_BYTES = 640 # 20ms at 16kHz 16-bit mono + +async def sender(ws, pcm_path: str): + # Handshake / Config + await ws.send(json.dumps({ + "type": "start", + "format": "pcm_s16le", + "sample_rate_hz": 16000, + "channels": 1 + })) + + print(f"Streaming {pcm_path}...") + with open(pcm_path, "rb") as f: + while True: + chunk = f.read(CHUNK_BYTES) + if not chunk: + break + await ws.send(chunk) + await asyncio.sleep(0) # Yield control to ensure receiver can process messages + + await ws.send(json.dumps({"type": "stop"})) + print("Finished sending audio.") + +async def receiver(ws): + async for message in ws: + try: + timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3] + evt = json.loads(message) + msg_type = evt.get('type') + text = evt.get('text', '') + lang = evt.get('language', '') + + if msg_type == 'ready': + print(f"[{timestamp}] [Server Ready]") + elif msg_type == 'partial': + # Overwrite line for partial updates to keep clean output + sys.stdout.write(f"\r[{timestamp}] [Partial] ({lang}): {text}") + sys.stdout.flush() + elif msg_type == 'final': + print(f"\n[{timestamp}] [Final] ({lang}): {text}") + elif msg_type == 'error': + print(f"\n[{timestamp}] [Error]: {evt.get('message')}") + else: + print(f"\n[{timestamp}] [Unknown]: {evt}") + + except json.JSONDecodeError: + print(f"\n[Raw]: {message}") + +async def main(): + parser = argparse.ArgumentParser(description="Qwen3-ASR Streaming Client") + parser.add_argument("-e", "--endpoint", required=True, help="WebSocket Endpoint URL (e.g. ws://localhost:8907/transcribe-streaming)") + parser.add_argument("-f", "--file", required=True, help="Path to raw PCM 16k 16-bit mono file (or WAV with correct format)") + args = parser.parse_args() + + print(f"Connecting to {args.endpoint}...") + + file_size = os.path.getsize(args.file) + duration = file_size / 32000.0 # 16000 * 2 bytes + print(f"Audio Duration: {duration:.2f}s") + + start_time = time.time() + try: + async with websockets.connect(args.endpoint, max_size=None) as ws: + await asyncio.gather(sender(ws, args.file), receiver(ws)) + + end_time = time.time() + process_time = end_time - start_time + rtf = process_time / duration + print(f"\nProcessing Time: {process_time:.2f}s") + print(f"Real-Time Factor (RTF): {rtf:.4f}") + + except Exception as e: + print(f"Connection failed: {e}") + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/inference-server/files/reference.m4a b/examples/inference-server/files/reference.m4a new file mode 100644 index 0000000..1ef00a8 Binary files /dev/null and b/examples/inference-server/files/reference.m4a differ diff --git a/examples/inference-server/files/reference.mp3 b/examples/inference-server/files/reference.mp3 new file mode 100644 index 0000000..d380e03 Binary files /dev/null and b/examples/inference-server/files/reference.mp3 differ diff --git a/examples/inference-server/files/reference.pcm b/examples/inference-server/files/reference.pcm new file mode 100644 index 0000000..45fcab8 Binary files /dev/null and b/examples/inference-server/files/reference.pcm differ diff --git a/examples/inference-server/files/reference.wav b/examples/inference-server/files/reference.wav new file mode 100644 index 0000000..c8403b0 Binary files /dev/null and b/examples/inference-server/files/reference.wav differ diff --git a/examples/inference-server/package-lock.json b/examples/inference-server/package-lock.json new file mode 100644 index 0000000..8da3955 --- /dev/null +++ b/examples/inference-server/package-lock.json @@ -0,0 +1,36 @@ +{ + "name": "qwen-asr-client", + "version": "1.0.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "qwen-asr-client", + "version": "1.0.0", + "dependencies": { + "ws": "^8.16.0" + } + }, + "node_modules/ws": { + "version": "8.19.0", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.19.0.tgz", + "integrity": "sha512-blAT2mjOEIi0ZzruJfIhb3nps74PRWTCz1IjglWEEpQl5XS/UNama6u2/rjFkDDouqr4L67ry+1aGIALViWjDg==", + "license": "MIT", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + } + } +} diff --git a/examples/inference-server/package.json b/examples/inference-server/package.json new file mode 100644 index 0000000..50206df --- /dev/null +++ b/examples/inference-server/package.json @@ -0,0 +1,9 @@ +{ + "name": "qwen-asr-inference-server-client", + "version": "1.0.0", + "description": "Streaming client for Qwen-ASR example inference server", + "type": "module", + "dependencies": { + "ws": "^8.16.0" + } +} diff --git a/examples/inference-server/requirements.txt b/examples/inference-server/requirements.txt new file mode 100644 index 0000000..3df1e7b --- /dev/null +++ b/examples/inference-server/requirements.txt @@ -0,0 +1,2 @@ +websockets +aiohttp diff --git a/examples/inference-server/server.py b/examples/inference-server/server.py new file mode 100644 index 0000000..18c62d4 --- /dev/null +++ b/examples/inference-server/server.py @@ -0,0 +1,469 @@ +import os +import json +import io +import asyncio +import logging +import subprocess +from typing import Optional, List, Tuple +from contextlib import asynccontextmanager +from concurrent.futures import ThreadPoolExecutor +import time + +import uvicorn +import numpy as np +import soundfile as sf +import torch +import psutil +from fastapi import FastAPI, UploadFile, File, WebSocket, WebSocketDisconnect, Query, HTTPException +from fastapi.middleware.cors import CORSMiddleware + +# Import Qwen-ASR components +try: + from qwen_asr import Qwen3ASRModel, Qwen3ForcedAligner +except ImportError: + print("Warning: qwen_asr not found.") + Qwen3ASRModel = None + Qwen3ForcedAligner = None + +# Logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# ----------------------------- +# Config +# ----------------------------- +def get_env_bool(key: str, default: str = "true") -> bool: + return os.getenv(key, default).lower() in ("true", "1", "yes", "on") + +MAX_CONCURRENT_DECODE = int(os.getenv("MAX_CONCURRENT_DECODE", "4")) +MAX_CONCURRENT_INFER = int(os.getenv("MAX_CONCURRENT_INFER", "1")) # GPU: usually 1 +THREADPOOL_WORKERS = int(os.getenv("THREADPOOL_WORKERS", str((os.cpu_count() or 4) * 5))) + +# Streaming buffering/throttling +STREAM_MIN_SAMPLES = int(os.getenv("STREAM_MIN_SAMPLES", "1600")) # 100ms @ 16kHz +PARTIAL_INTERVAL_MS = int(os.getenv("PARTIAL_INTERVAL_MS", "120")) # throttle partials +STREAM_EXPECT_SR = int(os.getenv("STREAM_EXPECT_SR", "16000")) + +# ----------------------------- +# App state +# ----------------------------- +models = {} +model_status = "starting" +model_ready_event = asyncio.Event() + +decode_sem = asyncio.Semaphore(MAX_CONCURRENT_DECODE) +infer_sem = asyncio.Semaphore(MAX_CONCURRENT_INFER) + +# ----------------------------- +# Helpers +# ----------------------------- +async def to_thread_limited(sem: asyncio.Semaphore, fn, *args, **kwargs): + async with sem: + return await asyncio.to_thread(fn, *args, **kwargs) + +def map_language(lang_code: Optional[str]) -> Optional[str]: + """Map ISO code to Qwen full name.""" + if lang_code is None: + return None + mapping = { + "en": "English", "de": "German", "fr": "French", "es": "Spanish", + "it": "Italian", "ja": "Japanese", "ko": "Korean", "zh": "Chinese", + "ru": "Russian", "pt": "Portuguese", "nl": "Dutch", "tr": "Turkish", + "sv": "Swedish", "id": "Indonesian", "vi": "Vietnamese", + "hi": "Hindi", "ar": "Arabic", + } + return mapping.get(lang_code.lower(), lang_code) + +def read_audio_file(file_bytes: bytes) -> Tuple[np.ndarray, int]: + """ + Sync decode. Must be called via asyncio.to_thread (or threadpool). + soundfile first; fallback to ffmpeg for mp3/m4a/etc. + """ + try: + with io.BytesIO(file_bytes) as f: + wav, sr = sf.read(f, dtype="float32", always_2d=False) + return wav, sr + except Exception: + process = subprocess.Popen( + ["ffmpeg", "-i", "pipe:0", "-f", "wav", "-"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, err = process.communicate(input=file_bytes) + if process.returncode != 0: + raise ValueError(f"FFmpeg decoding failed: {err.decode(errors='ignore')}") + with io.BytesIO(out) as f: + wav, sr = sf.read(f, dtype="float32", always_2d=False) + return wav, sr + +# ----------------------------- +# Model loading +# ----------------------------- +async def load_models_background(): + global model_status + logger.info("Background task: Loading models...") + model_status = "loading_models" + + async def _load_asr(): + global model_status + if not get_env_bool("ENABLE_ASR_MODEL", "true"): + logger.info("ASR Model disabled via ENABLE_ASR_MODEL.") + return + if Qwen3ASRModel is None: + raise RuntimeError("qwen_asr not installed (Qwen3ASRModel missing).") + + model_name = os.getenv("ASR_MODEL_NAME", "Qwen/Qwen3-ASR-1.7B") + logger.info(f"Loading ASR Model: {model_name}...") + gpu_mem = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.75")) + max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "4096")) + + try: + models["asr"] = await asyncio.to_thread( + Qwen3ASRModel.LLM, + model=model_name, + gpu_memory_utilization=gpu_mem, + max_new_tokens=max_new_tokens, + ) + logger.info("ASR Model loaded successfully.") + except Exception as e: + logger.exception(f"Failed to load ASR model: {e}") + model_status = "error" + raise + + async def _load_aligner(): + global model_status + if not get_env_bool("ENABLE_ALIGNER_MODEL", "true"): + logger.info("Aligner Model disabled via ENABLE_ALIGNER_MODEL.") + return + if Qwen3ForcedAligner is None: + raise RuntimeError("qwen_asr not installed (Qwen3ForcedAligner missing).") + + aligner_name = os.getenv("ALIGNER_MODEL_NAME", "Qwen/Qwen3-ForcedAligner-0.6B") + logger.info(f"Loading Aligner Model: {aligner_name}...") + + try: + models["aligner"] = await asyncio.to_thread( + Qwen3ForcedAligner.from_pretrained, + aligner_name, + dtype=torch.bfloat16, + device_map="cuda:0", + ) + logger.info("Aligner Model loaded successfully.") + except Exception as e: + logger.exception(f"Failed to load Aligner model: {e}") + model_status = "error" + raise + + try: + await asyncio.gather(_load_asr(), _load_aligner()) + except Exception: + # model_status already set to "error" by loaders + model_ready_event.set() # don't hang endpoints + return + + # Warmup (best-effort) + if "asr" in models: + logger.info("Warming up ASR model (best-effort)...") + model_status = "warming_up" + try: + dummy_wav = np.zeros(16000, dtype=np.float32) + dummy_sr = 16000 + + async with infer_sem: + await asyncio.to_thread( + models["asr"].transcribe, + audio=[(dummy_wav, dummy_sr)], + language=["English"], + return_time_stamps=False, + ) + + async with infer_sem: + state = await asyncio.to_thread( + models["asr"].init_streaming_state, + unfixed_chunk_num=2, + unfixed_token_num=5, + chunk_size_sec=2.0, + ) + + warmup_chunks = [320, 640, 1024, 3200] + [3200] * 25 + for n in warmup_chunks: + async with infer_sem: + await asyncio.to_thread(models["asr"].streaming_transcribe, dummy_wav[:n], state) + + async with infer_sem: + await asyncio.to_thread(models["asr"].finish_streaming_transcribe, state) + + logger.info("Warmup complete.") + except Exception as e: + logger.warning(f"Warmup failed (non-critical): {e}") + + model_status = "ready" + model_ready_event.set() + logger.info("Server is ready to accept requests.") + +# ----------------------------- +# Lifespan +# ----------------------------- +@asynccontextmanager +async def lifespan(app: FastAPI): + logger.info("Starting up Qwen3-ASR Server...") + + # Bigger threadpool helps when decoding + websocket buffering + other to_thread calls happen together. + executor = ThreadPoolExecutor(max_workers=THREADPOOL_WORKERS) + app.state.executor = executor + asyncio.get_running_loop().set_default_executor(executor) + + task = asyncio.create_task(load_models_background()) + try: + yield + finally: + # Shutdown + task.cancel() + models.clear() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + executor.shutdown(wait=False, cancel_futures=True) + logger.info("Shutdown complete.") + +# ----------------------------- +# App +# ----------------------------- +app = FastAPI(lifespan=lifespan) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# ----------------------------- +# Endpoints +# ----------------------------- +@app.get("/health") +async def health(): + mem = psutil.virtual_memory() + info = { + "status": model_status, + "limits": { + "max_concurrent_decode": MAX_CONCURRENT_DECODE, + "max_concurrent_infer": MAX_CONCURRENT_INFER, + "threadpool_workers": THREADPOOL_WORKERS, + }, + "memory": { + "ram_total_mb": mem.total // (1024 * 1024), + "ram_available_mb": mem.available // (1024 * 1024), + "ram_percent": mem.percent, + }, + } + if torch.cuda.is_available(): + info["memory"]["gpu_allocated_mb"] = torch.cuda.memory_allocated() // (1024 * 1024) + info["memory"]["gpu_reserved_mb"] = torch.cuda.memory_reserved() // (1024 * 1024) + return info + +@app.post("/transcribe") +async def transcribe( + files: List[UploadFile] = File(...), + language: Optional[str] = Query(None, description="Language code (e.g. en, de, fr). None for auto-detect."), + forced_alignment: bool = Query(False, description="Enable forced alignment (timestamps)"), +): + await model_ready_event.wait() + + if model_status != "ready": + raise HTTPException(status_code=503, detail=f"Server not ready: {model_status}") + if "asr" not in models: + raise HTTPException(status_code=503, detail="ASR model is not enabled or failed to load.") + + full_lang = map_language(language) + + async def decode_one(f: UploadFile): + content = await f.read() + return await to_thread_limited(decode_sem, read_audio_file, content) + + # Decode concurrently (limited) + try: + audio_batch = await asyncio.gather(*(decode_one(f) for f in files)) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Invalid audio file: {e}") + + # Inference (explicitly limited, because GPU concurrency is not free) + try: + async with infer_sem: + results = await asyncio.to_thread( + models["asr"].transcribe, + audio=audio_batch, + language=[full_lang] * len(audio_batch), + return_time_stamps=False, + ) + + response_list = [] + + if forced_alignment: + if "aligner" not in models: + raise HTTPException(status_code=503, detail="Aligner model is not enabled or failed to load.") + + texts = [r.text for r in results] + + async with infer_sem: + alignment_results = await asyncio.to_thread( + models["aligner"].align, + audio=audio_batch, + text=texts, + language=[full_lang] * len(audio_batch), + ) + + for i, res in enumerate(results): + response_list.append( + {"text": res.text, "language": res.language, "timestamps": alignment_results[i]} + ) + else: + for res in results: + response_list.append({"text": res.text, "language": res.language}) + + return response_list + + except HTTPException: + raise + except Exception as e: + logger.exception(f"Inference failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +@app.websocket("/transcribe-streaming") +async def websocket_endpoint( + ws: WebSocket, + language: Optional[str] = Query(None), + forced_alignment: bool = Query(False), # kept for API symmetry; not yet used in streaming +): + await ws.accept() + + # do wait until we know the outcome + await model_ready_event.wait() + + if model_status != "ready" or "asr" not in models: + await ws.close(code=1011, reason=f"Server not ready: {model_status}") + return + + full_lang = map_language(language) + client_sr = None + started = False + + # Init streaming state off event loop + limited concurrency (GPU touch) + try: + async with infer_sem: + state = await asyncio.to_thread( + models["asr"].init_streaming_state, + unfixed_chunk_num=2, + unfixed_token_num=5, + chunk_size_sec=2.0, + ) + except Exception as e: + logger.exception(f"Failed to init streaming state: {e}") + await ws.close(code=1011, reason="init_streaming_state failed") + return + + # Send ready + try: + await ws.send_json({"type": "ready"}) + except Exception: + return + + buf_parts: List[np.ndarray] = [] + buf_n = 0 + last_partial_ts = 0.0 + + async def flush_and_infer(send_partial: bool): + nonlocal buf_parts, buf_n, last_partial_ts + if buf_n <= 0: + return + chunk = np.concatenate(buf_parts, axis=0) if len(buf_parts) > 1 else buf_parts[0] + + async with infer_sem: + await asyncio.to_thread(models["asr"].streaming_transcribe, chunk, state) + + if send_partial: + now = time.monotonic() + if (now - last_partial_ts) * 1000.0 >= PARTIAL_INTERVAL_MS: + await ws.send_json({"type": "partial", "text": state.text, "language": state.language}) + + try: + while True: + msg = await ws.receive() + + if msg["type"] == "websocket.disconnect": + break + + if msg["type"] != "websocket.receive": + continue + + # Control messages + if msg.get("text"): + try: + data = json.loads(msg["text"]) + except json.JSONDecodeError: + data = None + + if isinstance(data, dict): + t = data.get("type") + + if t == "start": + started = True + client_sr = int(data.get("sample_rate_hz", 0)) if data.get("sample_rate_hz") else None + fmt = data.get("format") + + if client_sr != STREAM_EXPECT_SR or fmt not in (None, "pcm_s16le"): + await ws.send_json( + {"type": "error", "message": f"Only pcm_s16le @ {STREAM_EXPECT_SR}Hz supported"} + ) + await ws.close(code=1003) + return + + # Optional: acknowledge language selection + if full_lang is not None: + await ws.send_json({"type": "info", "message": f"language={full_lang}"}) + continue + + if t == "stop": + # Flush remainder, finish, send final + await flush_and_infer(send_partial=False) + async with infer_sem: + await asyncio.to_thread(models["asr"].finish_streaming_transcribe, state) + + await ws.send_json({"type": "final", "text": state.text, "language": state.language}) + await ws.close(code=1000) + return + + # Audio frames + if msg.get("bytes"): + if not started: + # Require explicit start so we can validate format. + await ws.send_json({"type": "error", "message": "Send {type:'start', format:'pcm_s16le', sample_rate_hz:16000} first"}) + await ws.close(code=1002) + return + + chunk_bytes = msg["bytes"] + # int16 mono little-endian -> float32 [-1, 1] + audio_int16 = np.frombuffer(chunk_bytes, dtype=np.int16) + if audio_int16.size == 0: + continue + + audio_f32 = audio_int16.astype(np.float32) / 32768.0 + buf_parts.append(audio_f32) + buf_n += audio_f32.size + + if buf_n >= STREAM_MIN_SAMPLES: + await flush_and_infer(send_partial=True) + + except WebSocketDisconnect: + pass + except Exception as e: + logger.exception(f"WS Error: {e}") + try: + await ws.close(code=1011, reason="internal error") + except Exception: + pass + +if __name__ == "__main__": + # NOTE: for GPU models, keep workers=1 unless you deliberately replicate the model per worker. + uvicorn.run(app, host="0.0.0.0", port=8000)