Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions mistral/voxtral-streaming-4b/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Baseten/Truss deployment config: serves Voxtral-Mini-4B-Realtime via vLLM
# over a websocket realtime endpoint.
# NOTE(review): nesting indentation restored — the pasted diff had it stripped,
# which is invalid YAML. Key/value content is unchanged.
model_name: Voxtral-Mini-4B-Realtime-2602
secrets:
  # Injected at runtime; read from /secrets/hf_access_token in start_command.
  hf_access_token: null
environment_variables:
  VLLM_DISABLE_COMPILE_CACHE: "1"
base_image:
  # Pinned vLLM nightly image (commit-tagged for reproducibility).
  image: vllm/vllm-openai:nightly-d88a1df699f68e5284fe3a3170f8ae292a3e9c3f
docker_server:
  start_command: sh -c "HF_TOKEN=$(cat /secrets/hf_access_token) VLLM_DISABLE_COMPILE_CACHE=1 vllm serve mistralai/Voxtral-Mini-4B-Realtime-2602 --compilation-config '{\"cudagraph_mode\":\"PIECEWISE\"}' --host 0.0.0.0 --port 8000"
  readiness_endpoint: /health
  liveness_endpoint: /health
  # Realtime audio goes over the OpenAI-style realtime websocket route.
  predict_endpoint: /v1/realtime
  server_port: 8000
resources:
  accelerator: H100_40GB:1
  cpu: "1"
  memory: 10Gi
  use_gpu: true
requirements:
  # Pip options line + packages, installed on top of the base image.
  - --pre --extra-index-url https://wheels.vllm.ai/nightly
  - vllm[audio]
  - librosa
  - torch
  - torchaudio
  - pynvml
  - ffmpeg-python
  - websockets
system_packages:
  - python3.10-venv
  - ffmpeg
  - openmpi-bin
  - libopenmpi-dev
runtime:
  is_websocket_endpoint: true
  transport:
    kind: websocket
    # null keepalive intervals -> use the platform defaults.
    ping_interval_seconds: null
    ping_timeout_seconds: null
164 changes: 164 additions & 0 deletions mistral/voxtral-streaming-4b/streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import asyncio
import base64
import json
import signal

import numpy as np
import sounddevice as sd
import websockets

# Audio capture parameters: the realtime endpoint consumes 16 kHz mono PCM16.
SAMPLE_RATE = 16_000
CHUNK_MS = 100  # send 100ms chunks
CHUNK_SAMPLES = int(SAMPLE_RATE * CHUNK_MS / 1000)  # samples per chunk (1600)

# Deployment credentials — both must be filled in before running.
model_id = ""  # Place model id here
BASETEN_API_KEY = ""  # Baseten API key here

# Websocket endpoint of the deployed Baseten model (production environment).
WS_URL = f"wss://model-{model_id}.api.baseten.co/environments/production/websocket"
MODEL = "mistralai/Voxtral-Mini-4B-Realtime-2602"

WARMUP_SECONDS = 2.0  # optional
SEND_COMMIT_EVERY_N_CHUNKS = 10  # optional: commit about once per second


def pcm16_to_b64(pcm16: np.ndarray) -> str:
    """Serialize a PCM16 sample buffer to a base64 ASCII string."""
    raw = pcm16.tobytes()
    return base64.b64encode(raw).decode("utf-8")


async def send_warmup_silence(ws):
    """Prime the server/model by streaming a short burst of silence (optional)."""
    n_samples = int(SAMPLE_RATE * WARMUP_SECONDS)
    silence = np.zeros(n_samples, dtype=np.int16)

    offset = 0
    while offset < n_samples:
        piece = silence[offset : offset + CHUNK_SAMPLES]
        payload = {
            "type": "input_audio_buffer.append",
            "audio": pcm16_to_b64(piece),
        }
        await ws.send(json.dumps(payload))
        # Pace the chunks in real time, like a live microphone would.
        await asyncio.sleep(CHUNK_MS / 1000)
        offset += CHUNK_SAMPLES


async def microphone_producer(audio_q: asyncio.Queue):
    """
    Stream microphone audio into *audio_q* as int16 PCM chunks.

    The sounddevice callback fires on a PortAudio thread, so each chunk is
    handed to the event loop via call_soon_threadsafe rather than touching
    the asyncio queue from the foreign thread.
    """
    loop = asyncio.get_running_loop()

    def on_audio(indata, frames, time_info, status):
        if status:
            # non-fatal stream warnings — deliberately ignored
            pass
        # indata is float32 in [-1, 1], shape (frames, channels); keep channel 0.
        samples = np.clip(indata[:, 0], -1.0, 1.0)
        pcm16 = (samples * 32767.0).astype(np.int16)
        loop.call_soon_threadsafe(audio_q.put_nowait, pcm16)

    mic = sd.InputStream(
        samplerate=SAMPLE_RATE,
        channels=1,
        dtype="float32",
        blocksize=CHUNK_SAMPLES,
        callback=on_audio,
    )

    with mic:
        # Keep the stream open until this task is cancelled; cancellation
        # surfaces inside the sleep and the `with` block closes the stream.
        while True:
            await asyncio.sleep(0.1)


async def send_audio(ws, audio_q: asyncio.Queue, stop_event: asyncio.Event):
    """Forward queued mic chunks to the websocket, committing periodically."""
    sent = 0
    while not stop_event.is_set():
        # Short timeout so the stop flag is re-checked even when the mic is idle.
        try:
            chunk = await asyncio.wait_for(audio_q.get(), timeout=0.5)
        except asyncio.TimeoutError:
            continue

        message = {
            "type": "input_audio_buffer.append",
            "audio": pcm16_to_b64(chunk),
        }
        await ws.send(json.dumps(message))

        sent += 1
        if sent % SEND_COMMIT_EVERY_N_CHUNKS == 0:
            await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))


async def receive_text(ws, stop_event: asyncio.Event):
    """Stream transcription deltas from the websocket to stdout as they arrive."""
    async for raw in ws:
        if stop_event.is_set():
            break

        try:
            event = json.loads(raw)
        except json.JSONDecodeError:
            # Ignore non-JSON frames.
            continue

        if event.get("type") == "transcription.delta":
            print(event.get("delta", ""), end="", flush=True)

        # Handle any other server event types you care about here:
        # elif event.get("type") == "...": ...


async def main():
    """Connect to the deployed model, stream mic audio, and print transcripts.

    Runs until SIGINT/SIGTERM sets the stop flag, then cancels the worker
    tasks and closes the websocket.
    """
    stop_event = asyncio.Event()
    # Bounded queue so a stalled sender cannot buffer unbounded mic audio.
    audio_q: asyncio.Queue[np.ndarray] = asyncio.Queue(maxsize=50)

    def request_stop(*_):
        stop_event.set()

    # Ctrl+C / SIGTERM handling: handlers run in the main thread, where the
    # event loop lives, so setting the Event directly is safe here.
    signal.signal(signal.SIGINT, request_stop)
    signal.signal(signal.SIGTERM, request_stop)

    async with websockets.connect(
        # NOTE(review): newer websockets releases renamed `extra_headers` to
        # `additional_headers` for the asyncio client — confirm against the
        # pinned websockets version before upgrading.
        WS_URL, extra_headers={"Authorization": f"Api-Key {BASETEN_API_KEY}"}
    ) as ws:
        # Some servers send an initial "hello"/ack; try to read it once
        # (non-fatal if nothing arrives before the timeout).
        try:
            _ = await asyncio.wait_for(ws.recv(), timeout=2)
        except Exception:
            pass

        print("[Connection established]")
        print("Start speaking 🎙️...")

        # Configure session/model
        await ws.send(json.dumps({"type": "session.update", "model": MODEL}))

        # Optional warmup
        await send_warmup_silence(ws)
        await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))

        # Start the mic capture, sender, and receiver tasks.
        tasks = [
            asyncio.create_task(microphone_producer(audio_q)),
            asyncio.create_task(send_audio(ws, audio_q, stop_event)),
            asyncio.create_task(receive_text(ws, stop_event)),
        ]

        # Wait for stop (Ctrl+C)
        while not stop_event.is_set():
            await asyncio.sleep(0.1)

        # Cleanup: cancel AND await the tasks so cancellation actually runs
        # (e.g. the mic InputStream context manager closes) before the socket
        # goes away. Fire-and-forget cancel leaves tasks pending at shutdown.
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        await ws.close()


# Script entry point: start the realtime streaming client.
if __name__ == "__main__":
    asyncio.run(main())
Loading