Skip to content

Commit b193494

Browse files
authored
Adding voxtral model (#544)
1 parent d743896 commit b193494

File tree

2 files changed

+202
-0
lines changed

2 files changed

+202
-0
lines changed
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Baseten deployment config: serves Mistral's Voxtral-Mini realtime speech model
# with vLLM (nightly build) behind a websocket endpoint.
model_name: Voxtral-Mini-4B-Realtime-2602
secrets:
  # Hugging Face token; injected at runtime and read from /secrets by start_command.
  hf_access_token: null
environment_variables:
  VLLM_DISABLE_COMPILE_CACHE: "1"
base_image:
  # Pinned vLLM nightly image — presumably required for Voxtral realtime support;
  # verify before bumping the digest.
  image: vllm/vllm-openai:nightly-d88a1df699f68e5284fe3a3170f8ae292a3e9c3f
docker_server:
  # The escaped JSON passes vLLM's compilation config inline (PIECEWISE cudagraphs).
  start_command: sh -c "HF_TOKEN=$(cat /secrets/hf_access_token) VLLM_DISABLE_COMPILE_CACHE=1 vllm serve mistralai/Voxtral-Mini-4B-Realtime-2602 --compilation-config '{\"cudagraph_mode\":\"PIECEWISE\"}' --host 0.0.0.0 --port 8000"
  readiness_endpoint: /health
  liveness_endpoint: /health
  # Realtime (websocket) inference endpoint exposed by vLLM.
  predict_endpoint: /v1/realtime
  server_port: 8000
resources:
  accelerator: H100_40GB:1
  cpu: "1"
  memory: 10Gi
  use_gpu: true
requirements:
  # First entry adds pip flags so the vllm[audio] line below resolves from the nightly wheel index.
  - --pre --extra-index-url https://wheels.vllm.ai/nightly
  - vllm[audio]
  - librosa
  - torch
  - torchaudio
  - pynvml
  - ffmpeg-python
  - websockets
system_packages:
  - python3.10-venv
  - ffmpeg
  - openmpi-bin
  - libopenmpi-dev
runtime:
  is_websocket_endpoint: true
transport:
  kind: websocket
  # null disables client keepalive pings/timeouts on the transport.
  ping_interval_seconds: null
  ping_timeout_seconds: null
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
import asyncio
2+
import base64
3+
import json
4+
import signal
5+
6+
import numpy as np
7+
import sounddevice as sd
8+
import websockets
9+
10+
# --- Audio capture parameters (mono PCM16) ---
SAMPLE_RATE = 16_000  # Hz; must match what the realtime endpoint expects
CHUNK_MS = 100  # send 100ms chunks
CHUNK_SAMPLES = int(SAMPLE_RATE * CHUNK_MS / 1000)  # samples per chunk (1600)

# --- Deployment credentials: fill these in before running ---
model_id = ""  # Place model id here
BASETEN_API_KEY = ""  # Baseten API key here

# Websocket endpoint of the deployed model's production environment.
WS_URL = f"wss://model-{model_id}.api.baseten.co/environments/production/websocket"
MODEL = "mistralai/Voxtral-Mini-4B-Realtime-2602"

WARMUP_SECONDS = 2.0  # optional: seconds of silence sent before real audio
SEND_COMMIT_EVERY_N_CHUNKS = 10  # optional: commit about once per second
22+
23+
24+
def pcm16_to_b64(pcm16: np.ndarray) -> str:
    """Encode raw int16 PCM samples as a base64 ASCII string."""
    raw = pcm16.tobytes()
    return base64.b64encode(raw).decode("utf-8")
26+
27+
28+
async def send_warmup_silence(ws):
    """Feed WARMUP_SECONDS of zero-valued PCM16 audio to the server (optional).

    Gives the server/model a chance to warm up before real speech arrives;
    chunks are paced at roughly real time.
    """
    total_samples = int(SAMPLE_RATE * WARMUP_SECONDS)
    silence = np.zeros(total_samples, dtype=np.int16)
    pause = CHUNK_MS / 1000

    offset = 0
    while offset < total_samples:
        chunk = silence[offset : offset + CHUNK_SAMPLES]
        payload = {
            "type": "input_audio_buffer.append",
            "audio": pcm16_to_b64(chunk),
        }
        await ws.send(json.dumps(payload))
        await asyncio.sleep(pause)
        offset += CHUNK_SAMPLES
44+
45+
46+
async def microphone_producer(audio_q: asyncio.Queue):
    """
    Capture mic audio and push PCM16 chunks into an asyncio.Queue.

    The sounddevice callback runs on a separate (PortAudio) thread, so chunks
    are handed to the event loop via call_soon_threadsafe. Runs until the
    task is cancelled.
    """
    loop = asyncio.get_running_loop()

    def enqueue(pcm16: np.ndarray) -> None:
        # Runs on the event loop thread. The queue is bounded, so a lagging
        # consumer would make put_nowait raise QueueFull — an unhandled
        # exception in the loop. Dropping the chunk is the right behavior for
        # realtime audio: stale audio is worse than missing audio.
        try:
            audio_q.put_nowait(pcm16)
        except asyncio.QueueFull:
            pass  # consumer lagging; drop this chunk

    def callback(indata, frames, time_info, status):
        if status:
            # non-fatal stream warnings (over/underflow etc.); ignored
            pass
        # indata is float32 in [-1, 1], shape (frames, channels) — take channel 0
        mono = indata[:, 0]
        pcm16 = (np.clip(mono, -1.0, 1.0) * 32767.0).astype(np.int16)
        loop.call_soon_threadsafe(enqueue, pcm16)

    stream = sd.InputStream(
        samplerate=SAMPLE_RATE,
        channels=1,
        dtype="float32",
        blocksize=CHUNK_SAMPLES,
        callback=callback,
    )

    with stream:
        # run until cancelled; the callback does the actual work
        while True:
            await asyncio.sleep(0.1)
74+
75+
76+
async def send_audio(ws, audio_q: asyncio.Queue, stop_event: asyncio.Event):
    """Forward microphone chunks from the queue to the websocket.

    Emits an input_audio_buffer.commit every SEND_COMMIT_EVERY_N_CHUNKS
    chunks (about once per second at the default chunk size).
    """
    chunks_sent = 0
    while not stop_event.is_set():
        try:
            pcm16 = await asyncio.wait_for(audio_q.get(), timeout=0.5)
        except asyncio.TimeoutError:
            # No audio yet — re-check the stop flag and keep waiting.
            continue

        message = {
            "type": "input_audio_buffer.append",
            "audio": pcm16_to_b64(pcm16),
        }
        await ws.send(json.dumps(message))

        chunks_sent += 1
        if chunks_sent % SEND_COMMIT_EVERY_N_CHUNKS == 0:
            await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))
97+
98+
99+
async def receive_text(ws, stop_event: asyncio.Event):
    """Stream transcription deltas from the websocket to stdout."""
    async for raw in ws:
        if stop_event.is_set():
            break

        try:
            event = json.loads(raw)
        except json.JSONDecodeError:
            continue  # ignore non-JSON frames

        if event.get("type") != "transcription.delta":
            # If your server emits other event types you care about,
            # handle them here.
            continue
        print(event.get("delta", ""), end="", flush=True)
116+
117+
118+
async def main():
    """Connect to the realtime endpoint, stream mic audio, print transcripts.

    Runs until SIGINT/SIGTERM (Ctrl+C), then cancels the worker tasks,
    waits for their cancellation to complete, and closes the websocket.
    """
    stop_event = asyncio.Event()
    audio_q: asyncio.Queue[np.ndarray] = asyncio.Queue(maxsize=50)

    def request_stop(*_):
        stop_event.set()

    # Ctrl+C handling — plain signal handlers are fine here because main()
    # runs on the main thread; the 0.1s poll below notices the flag.
    signal.signal(signal.SIGINT, request_stop)
    signal.signal(signal.SIGTERM, request_stop)

    async with websockets.connect(
        WS_URL, extra_headers={"Authorization": f"Api-Key {BASETEN_API_KEY}"}
    ) as ws:
        # Some servers send an initial "hello"/ack; we can just try to read
        # once (non-fatal if it times out)
        try:
            _ = await asyncio.wait_for(ws.recv(), timeout=2)
        except Exception:
            pass

        print("[Connection established]")
        print("Start speaking 🎙️...")

        # Configure session/model
        await ws.send(json.dumps({"type": "session.update", "model": MODEL}))

        # Optional warmup
        await send_warmup_silence(ws)
        await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))

        # Start worker tasks
        tasks = [
            asyncio.create_task(microphone_producer(audio_q)),
            asyncio.create_task(send_audio(ws, audio_q, stop_event)),
            asyncio.create_task(receive_text(ws, stop_event)),
        ]

        # Wait for stop (Ctrl+C)
        while not stop_event.is_set():
            await asyncio.sleep(0.1)

        # Cleanup: cancel the workers AND await them, so cancellation actually
        # completes before the socket closes (otherwise asyncio logs
        # "Task was destroyed but it is pending!" at shutdown).
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        await ws.close()
161+
162+
163+
# Script entry point: stream mic audio to the realtime endpoint until Ctrl+C.
if __name__ == "__main__":
    asyncio.run(main())

0 commit comments

Comments
 (0)