Skip to content

Commit 3c95188

Browse files
authored
Merge branch 'main' into feature/irodori-tts
2 parents b243c65 + 0cbb6d8 commit 3c95188

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

mlx_audio/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ async def _stream_transcription(
522522
"""Handle both streaming and non-streaming model inference over WebSocket.
523523
524524
Streaming models (whose generate() accepts a ``stream`` parameter) receive
525-
the numpy array directly and yield token deltas sent as
525+
the audio as an ``mx.array`` and yield token deltas sent as
526526
``{"type": "delta", "delta": "..."}`` messages, followed by a
527527
``{"type": "complete", ...}`` message.
528528
@@ -533,7 +533,7 @@ async def _stream_transcription(
533533

534534
if supports_stream and streaming:
535535
result_iter = stt_model.generate(
536-
audio_array, stream=True, language=language, verbose=False
536+
mx.array(audio_array), stream=True, language=language, verbose=False
537537
)
538538
accumulated = ""
539539
detected_language = language

mlx_audio/tests/test_server.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import io
33
from unittest.mock import AsyncMock, MagicMock, patch
44

5+
import mlx.core as mx
56
import numpy as np
67
import pytest
78

@@ -319,20 +320,34 @@ def test_realtime_ws_streaming_structured_chunks(client, mock_model_provider):
319320
assert "Hello" in combined
320321

321322

322-
def test_realtime_ws_numpy_direct_pass(client, mock_model_provider):
323-
"""Streaming models receive numpy arrays directly, not file paths."""
323+
def test_realtime_ws_mx_array_pass(client, mock_model_provider):
324+
"""Streaming models receive mx.array, not file paths."""
324325
gen_fn = _make_streaming_generate(["test"])
325326
_, mock_stt_model = _ws_send_audio_and_collect(client, mock_model_provider, gen_fn)
326327

327-
# Check that generate was called with a numpy array (not a string path)
328+
# Check that generate was called with an mx.array (not a string path)
328329
tracked = mock_stt_model.generate
329330
assert len(tracked.call_args_list) > 0, "generate was never called"
330331
first_arg = tracked.call_args_list[0][0][
331332
0
332333
] # first call, positional args, first arg
333-
assert isinstance(
334-
first_arg, np.ndarray
335-
), f"Expected numpy array, got {type(first_arg)}"
334+
assert isinstance(first_arg, mx.array), f"Expected mx.array, got {type(first_arg)}"
335+
336+
337+
def test_realtime_ws_mx_array_supports_bfloat16_cast(client, mock_model_provider):
338+
"""Regression: models like Parakeet that cast to bfloat16 must receive mx.array."""
339+
340+
def gen_fn(audio, *, stream=False, language=None, verbose=False, **kwargs):
341+
if stream:
342+
# Parakeet's stream_generate does this internally
343+
_ = audio.astype(mx.bfloat16)
344+
return iter(["ok"])
345+
return MagicMock(text="ok", segments=None, language=None)
346+
347+
messages, _ = _ws_send_audio_and_collect(client, mock_model_provider, gen_fn)
348+
completes = [m for m in messages if m.get("type") == "complete"]
349+
assert len(completes) >= 1
350+
assert completes[0]["text"] == "ok"
336351

337352

338353
def test_realtime_ws_streaming_disabled_fallback(client, mock_model_provider):

0 commit comments

Comments
 (0)