|
2 | 2 | import io |
3 | 3 | from unittest.mock import AsyncMock, MagicMock, patch |
4 | 4 |
|
| 5 | +import mlx.core as mx |
5 | 6 | import numpy as np |
6 | 7 | import pytest |
7 | 8 |
|
@@ -319,20 +320,34 @@ def test_realtime_ws_streaming_structured_chunks(client, mock_model_provider): |
319 | 320 | assert "Hello" in combined |
320 | 321 |
|
321 | 322 |
|
322 | | -def test_realtime_ws_numpy_direct_pass(client, mock_model_provider): |
323 | | - """Streaming models receive numpy arrays directly, not file paths.""" |
| 323 | +def test_realtime_ws_mx_array_pass(client, mock_model_provider): |
| 324 | + """Streaming models receive mx.array, not file paths.""" |
324 | 325 | gen_fn = _make_streaming_generate(["test"]) |
325 | 326 | _, mock_stt_model = _ws_send_audio_and_collect(client, mock_model_provider, gen_fn) |
326 | 327 |
|
327 | | - # Check that generate was called with a numpy array (not a string path) |
| 328 | + # Check that generate was called with an mx.array (not a string path) |
328 | 329 | tracked = mock_stt_model.generate |
329 | 330 | assert len(tracked.call_args_list) > 0, "generate was never called" |
330 | 331 | first_arg = tracked.call_args_list[0][0][ |
331 | 332 | 0 |
332 | 333 | ] # first call, positional args, first arg |
333 | | - assert isinstance( |
334 | | - first_arg, np.ndarray |
335 | | - ), f"Expected numpy array, got {type(first_arg)}" |
| 334 | + assert isinstance(first_arg, mx.array), f"Expected mx.array, got {type(first_arg)}" |
| 335 | + |
| 336 | + |
| 337 | +def test_realtime_ws_mx_array_supports_bfloat16_cast(client, mock_model_provider): |
| 338 | + """Regression: models like Parakeet that cast to bfloat16 must receive mx.array.""" |
| 339 | + |
| 340 | + def gen_fn(audio, *, stream=False, language=None, verbose=False, **kwargs): |
| 341 | + if stream: |
| 342 | + # Parakeet's stream_generate does this internally |
| 343 | + _ = audio.astype(mx.bfloat16) |
| 344 | + return iter(["ok"]) |
| 345 | + return MagicMock(text="ok", segments=None, language=None) |
| 346 | + |
| 347 | + messages, _ = _ws_send_audio_and_collect(client, mock_model_provider, gen_fn) |
| 348 | + completes = [m for m in messages if m.get("type") == "complete"] |
| 349 | + assert len(completes) >= 1 |
| 350 | + assert completes[0]["text"] == "ok" |
336 | 351 |
|
337 | 352 |
|
338 | 353 | def test_realtime_ws_streaming_disabled_fallback(client, mock_model_provider): |
|
0 commit comments