|
| 1 | +import asyncio |
| 2 | +import base64 |
| 3 | +import json |
| 4 | +import signal |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +import sounddevice as sd |
| 8 | +import websockets |
| 9 | + |
# Audio format: 16 kHz mono PCM16, streamed in fixed-size chunks.
SAMPLE_RATE = 16_000
CHUNK_MS = 100  # send 100ms chunks
CHUNK_SAMPLES = int(SAMPLE_RATE * CHUNK_MS / 1000)

# Deployment credentials -- both must be filled in before running.
model_id = ""  # Place model id here
BASETEN_API_KEY = ""  # Baseten API key here

# Realtime websocket endpoint for the production deployment of the model.
WS_URL = f"wss://model-{model_id}.api.baseten.co/environments/production/websocket"
MODEL = "mistralai/Voxtral-Mini-4B-Realtime-2602"

WARMUP_SECONDS = 2.0  # optional
SEND_COMMIT_EVERY_N_CHUNKS = 10  # optional: commit about once per second
| 22 | + |
| 23 | + |
def pcm16_to_b64(pcm16: np.ndarray) -> str:
    """Serialize an int16 PCM sample array to a base64-encoded ASCII string."""
    raw = pcm16.tobytes()
    return base64.b64encode(raw).decode("utf-8")
| 26 | + |
| 27 | + |
async def send_warmup_silence(ws):
    """Stream WARMUP_SECONDS of zeroed PCM16 so the server/model warms up (optional)."""
    total_samples = int(SAMPLE_RATE * WARMUP_SECONDS)
    silence = np.zeros(total_samples, dtype=np.int16)

    for start in range(0, total_samples, CHUNK_SAMPLES):
        payload = {
            "type": "input_audio_buffer.append",
            "audio": pcm16_to_b64(silence[start : start + CHUNK_SAMPLES]),
        }
        await ws.send(json.dumps(payload))
        # Pace the chunks at real time so the stream resembles live audio.
        await asyncio.sleep(CHUNK_MS / 1000)
| 44 | + |
| 45 | + |
async def microphone_producer(audio_q: asyncio.Queue):
    """
    Capture mic audio and push PCM16 chunks into an asyncio.Queue.

    The sounddevice callback runs on a PortAudio thread, so chunks are handed
    to the event loop via call_soon_threadsafe. If the queue is full (the
    consumer fell behind), the chunk is dropped instead of letting QueueFull
    escape into the event loop's exception handler -- dropping is the right
    best-effort behavior for a live audio stream. Runs until cancelled.
    """
    loop = asyncio.get_running_loop()

    def _enqueue(chunk: np.ndarray) -> None:
        # Runs on the event loop thread. The queue is bounded (see main()),
        # so a lagging consumer would otherwise raise asyncio.QueueFull here.
        try:
            audio_q.put_nowait(chunk)
        except asyncio.QueueFull:
            pass  # drop the chunk; stale audio is worse than missing audio

    def callback(indata, frames, time_info, status):
        # `status` reports non-fatal stream warnings (over/underruns); ignored.
        # indata is float32 in [-1, 1], shape (frames, channels); keep channel 0.
        mono = indata[:, 0]
        pcm16 = (np.clip(mono, -1.0, 1.0) * 32767.0).astype(np.int16)
        loop.call_soon_threadsafe(_enqueue, pcm16)

    stream = sd.InputStream(
        samplerate=SAMPLE_RATE,
        channels=1,
        dtype="float32",
        blocksize=CHUNK_SAMPLES,
        callback=callback,
    )

    with stream:
        # Idle until the task is cancelled; the callback does the real work.
        while True:
            await asyncio.sleep(0.1)
| 74 | + |
| 75 | + |
async def send_audio(ws, audio_q: asyncio.Queue, stop_event: asyncio.Event):
    """Forward queued mic chunks to the websocket until stop_event is set."""
    chunks_sent = 0
    while not stop_event.is_set():
        try:
            chunk = await asyncio.wait_for(audio_q.get(), timeout=0.5)
        except asyncio.TimeoutError:
            # No audio yet -- loop back to re-check stop_event.
            continue

        message = {
            "type": "input_audio_buffer.append",
            "audio": pcm16_to_b64(chunk),
        }
        await ws.send(json.dumps(message))

        chunks_sent += 1
        # Commit roughly once per second so the server finalizes pending audio.
        if chunks_sent % SEND_COMMIT_EVERY_N_CHUNKS == 0:
            await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))
| 97 | + |
| 98 | + |
async def receive_text(ws, stop_event: asyncio.Event):
    """Stream server events, printing transcription deltas as they arrive."""
    async for raw in ws:
        if stop_event.is_set():
            break

        try:
            event = json.loads(raw)
        except json.JSONDecodeError:
            # Skip anything that isn't valid JSON (e.g. keepalives).
            continue

        if event.get("type") != "transcription.delta":
            # If your server emits other event types you care about, handle them here.
            continue

        print(event.get("delta", ""), end="", flush=True)
| 116 | + |
| 117 | + |
async def main():
    """Connect to the realtime endpoint, stream mic audio, print transcripts.

    Runs until SIGINT/SIGTERM, then cancels the worker tasks, waits for
    the cancellations to complete, and closes the websocket.
    """
    stop_event = asyncio.Event()
    audio_q: asyncio.Queue[np.ndarray] = asyncio.Queue(maxsize=50)

    def request_stop(*_):
        stop_event.set()

    # Ctrl+C / kill handling: flip the event; main() and the send loop notice it.
    signal.signal(signal.SIGINT, request_stop)
    signal.signal(signal.SIGTERM, request_stop)

    # NOTE(review): newer `websockets` releases renamed `extra_headers` to
    # `additional_headers` in the asyncio client -- confirm against the
    # installed version.
    async with websockets.connect(
        WS_URL, extra_headers={"Authorization": f"Api-Key {BASETEN_API_KEY}"}
    ) as ws:
        # Some servers send an initial "hello"/ack; try to read it once.
        # Best-effort: a timeout or early read error here is non-fatal.
        try:
            _ = await asyncio.wait_for(ws.recv(), timeout=2)
        except Exception:
            pass

        print("[Connection established]")
        print("Start speaking 🎙️...")

        # Configure session/model
        await ws.send(json.dumps({"type": "session.update", "model": MODEL}))

        # Optional warmup: prime the pipeline with silence, then commit it.
        await send_warmup_silence(ws)
        await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))

        # Start the capture / send / receive workers.
        tasks = [
            asyncio.create_task(microphone_producer(audio_q)),
            asyncio.create_task(send_audio(ws, audio_q, stop_event)),
            asyncio.create_task(receive_text(ws, stop_event)),
        ]

        # Block until a signal requests shutdown (no polling needed).
        await stop_event.wait()

        # Cleanup: cancel the workers AND await them, so cancellation actually
        # completes before the socket closes and no "Task was destroyed but it
        # is pending" warnings are emitted.
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        await ws.close()
| 161 | + |
| 162 | + |
# Script entry point: run the async client until interrupted.
if __name__ == "__main__":
    asyncio.run(main())
0 commit comments