Skip to content

Commit 77f73ac

Browse files
committed
feat: add tests for WebSocket streaming with reference_id and references parameters
Signed-off-by: James Ding <[email protected]>
1 parent c65fcaa commit 77f73ac

File tree

1 file changed

+319
-1
lines changed

1 file changed

+319
-1
lines changed

tests/unit/test_tts_realtime.py

Lines changed: 319 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from fishaudio.core import ClientWrapper, AsyncClientWrapper
77
from fishaudio.resources.tts import TTSClient, AsyncTTSClient
8-
from fishaudio.types import Prosody, TTSConfig, TextEvent, FlushEvent
8+
from fishaudio.types import Prosody, TTSConfig, TextEvent, FlushEvent, ReferenceAudio
9+
import ormsgpack
910

1011

1112
@pytest.fixture
@@ -181,6 +182,169 @@ def test_stream_websocket_max_workers(
181182
# Verify ThreadPoolExecutor was created with max_workers=5
182183
mock_executor.assert_called_once_with(max_workers=5)
183184

185+
@patch("fishaudio.resources.tts.connect_ws")
@patch("fishaudio.resources.tts.ThreadPoolExecutor")
def test_stream_websocket_with_reference_id_parameter(
    self, mock_executor, mock_connect_ws, tts_client, mock_client_wrapper
):
    """Check that a ``reference_id`` keyword argument reaches the StartEvent.

    The WebSocket connection, the thread pool and the audio receiver are all
    mocked. The first payload passed to ``send_bytes`` is decoded with
    ormsgpack and its ``request.reference_id`` field is compared against the
    value given to ``stream_websocket``.
    """
    # Fake WebSocket that works as a context manager yielding itself.
    ws = MagicMock()
    ws.__enter__ = Mock(return_value=ws)
    ws.__exit__ = Mock(return_value=None)
    ws.send_bytes = Mock()
    mock_connect_ws.return_value = ws

    def run_inline(fn):
        # Call the submitted function right away so its sends are recorded
        # before the assertions below run.
        fn()
        finished = Mock()
        finished.result.return_value = None
        return finished

    pool = Mock()
    pool.submit.side_effect = run_inline
    mock_executor.return_value = pool

    with patch("fishaudio.resources.tts.iter_websocket_audio") as mock_receiver:
        mock_receiver.return_value = iter([b"audio"])

        # Drain the generator so the whole streaming path executes.
        list(tts_client.stream_websocket(iter(["Test"]), reference_id="voice_456"))

        # At least one message must have been sent; the first is the StartEvent.
        assert ws.send_bytes.called
        start_call = ws.send_bytes.call_args_list[0]
        payload = ormsgpack.unpackb(start_call.args[0])
        assert payload["request"]["reference_id"] == "voice_456"
221+
222+
@patch("fishaudio.resources.tts.connect_ws")
@patch("fishaudio.resources.tts.ThreadPoolExecutor")
def test_stream_websocket_config_reference_id_overrides_parameter(
    self, mock_executor, mock_connect_ws, tts_client, mock_client_wrapper
):
    """Assert that ``config.reference_id`` wins over the ``reference_id`` argument.

    Both a ``TTSConfig`` carrying ``reference_id="voice_from_config"`` and a
    direct ``reference_id="voice_from_param"`` are passed to
    ``stream_websocket``; the decoded StartEvent is expected to carry the
    config value.
    """
    # Fake WebSocket that works as a context manager yielding itself.
    ws = MagicMock()
    ws.__enter__ = Mock(return_value=ws)
    ws.__exit__ = Mock(return_value=None)
    ws.send_bytes = Mock()
    mock_connect_ws.return_value = ws

    def run_inline(fn):
        # Call the submitted function right away so its sends are recorded
        # before the assertions below run.
        fn()
        finished = Mock()
        finished.result.return_value = None
        return finished

    pool = Mock()
    pool.submit.side_effect = run_inline
    mock_executor.return_value = pool

    config = TTSConfig(reference_id="voice_from_config")

    with patch("fishaudio.resources.tts.iter_websocket_audio") as mock_receiver:
        mock_receiver.return_value = iter([b"audio"])

        # Drain the generator so the whole streaming path executes.
        list(
            tts_client.stream_websocket(
                iter(["Test"]), reference_id="voice_from_param", config=config
            )
        )

        # Fail with a clear assertion (not an IndexError) if nothing was sent.
        assert ws.send_bytes.called
        start_call = ws.send_bytes.call_args_list[0]
        payload = ormsgpack.unpackb(start_call.args[0])
        assert payload["request"]["reference_id"] == "voice_from_config"
261+
262+
@patch("fishaudio.resources.tts.connect_ws")
@patch("fishaudio.resources.tts.ThreadPoolExecutor")
def test_stream_websocket_with_references_parameter(
    self, mock_executor, mock_connect_ws, tts_client, mock_client_wrapper
):
    """Check that a ``references`` keyword argument reaches the StartEvent.

    Two ``ReferenceAudio`` objects are passed directly to
    ``stream_websocket``; the decoded StartEvent is expected to list both,
    in order, identified here by their ``text`` fields.
    """
    # Fake WebSocket that works as a context manager yielding itself.
    ws = MagicMock()
    ws.__enter__ = Mock(return_value=ws)
    ws.__exit__ = Mock(return_value=None)
    ws.send_bytes = Mock()
    mock_connect_ws.return_value = ws

    def run_inline(fn):
        # Call the submitted function right away so its sends are recorded
        # before the assertions below run.
        fn()
        finished = Mock()
        finished.result.return_value = None
        return finished

    pool = Mock()
    pool.submit.side_effect = run_inline
    mock_executor.return_value = pool

    voice_samples = [
        ReferenceAudio(audio=b"ref_audio_1", text="Sample 1"),
        ReferenceAudio(audio=b"ref_audio_2", text="Sample 2"),
    ]

    with patch("fishaudio.resources.tts.iter_websocket_audio") as mock_receiver:
        mock_receiver.return_value = iter([b"audio"])

        # Drain the generator so the whole streaming path executes.
        list(tts_client.stream_websocket(iter(["Test"]), references=voice_samples))

        # Fail with a clear assertion (not an IndexError) if nothing was sent.
        assert ws.send_bytes.called
        start_call = ws.send_bytes.call_args_list[0]
        payload = ormsgpack.unpackb(start_call.args[0])
        sent_refs = payload["request"]["references"]
        assert len(sent_refs) == 2
        assert sent_refs[0]["text"] == "Sample 1"
        assert sent_refs[1]["text"] == "Sample 2"
303+
304+
@patch("fishaudio.resources.tts.connect_ws")
@patch("fishaudio.resources.tts.ThreadPoolExecutor")
def test_stream_websocket_config_references_overrides_parameter(
    self, mock_executor, mock_connect_ws, tts_client, mock_client_wrapper
):
    """Assert that ``config.references`` wins over the ``references`` argument.

    One ``ReferenceAudio`` is supplied through ``TTSConfig`` and a different
    one through the ``references`` keyword; the decoded StartEvent is
    expected to contain only the config entry.
    """
    # Fake WebSocket that works as a context manager yielding itself.
    ws = MagicMock()
    ws.__enter__ = Mock(return_value=ws)
    ws.__exit__ = Mock(return_value=None)
    ws.send_bytes = Mock()
    mock_connect_ws.return_value = ws

    def run_inline(fn):
        # Call the submitted function right away so its sends are recorded
        # before the assertions below run.
        fn()
        finished = Mock()
        finished.result.return_value = None
        return finished

    pool = Mock()
    pool.submit.side_effect = run_inline
    mock_executor.return_value = pool

    config_refs = [ReferenceAudio(audio=b"config_audio", text="Config")]
    param_refs = [ReferenceAudio(audio=b"param_audio", text="Param")]
    config = TTSConfig(references=config_refs)

    with patch("fishaudio.resources.tts.iter_websocket_audio") as mock_receiver:
        mock_receiver.return_value = iter([b"audio"])

        # Drain the generator so the whole streaming path executes.
        list(
            tts_client.stream_websocket(
                iter(["Test"]), references=param_refs, config=config
            )
        )

        # Fail with a clear assertion (not an IndexError) if nothing was sent.
        assert ws.send_bytes.called
        start_call = ws.send_bytes.call_args_list[0]
        payload = ormsgpack.unpackb(start_call.args[0])
        sent_refs = payload["request"]["references"]
        assert len(sent_refs) == 1
        assert sent_refs[0]["text"] == "Config"
347+
184348

185349
class TestAsyncTTSRealtimeClient:
186350
"""Test asynchronous AsyncTTSClient realtime streaming."""
@@ -331,3 +495,157 @@ async def text_stream():
331495

332496
# Should have no audio
333497
assert audio_chunks == []
498+
499+
@pytest.mark.asyncio
@patch("fishaudio.resources.tts.aconnect_ws")
async def test_stream_websocket_with_reference_id_parameter(
    self, mock_aconnect_ws, async_tts_client, async_mock_client_wrapper
):
    """Check that a ``reference_id`` keyword argument reaches the StartEvent (async).

    The async WebSocket connection and the audio receiver are mocked. The
    first payload passed to ``send_bytes`` is decoded with ormsgpack and its
    ``request.reference_id`` field is compared against the value given to
    ``stream_websocket``.
    """
    # Fake WebSocket that works as an async context manager yielding itself.
    ws = MagicMock()
    ws.__aenter__ = AsyncMock(return_value=ws)
    ws.__aexit__ = AsyncMock(return_value=None)
    ws.send_bytes = AsyncMock()
    mock_aconnect_ws.return_value = ws

    async def fake_audio(_ws):
        yield b"audio"

    async def text_stream():
        yield "Test"

    with patch(
        "fishaudio.resources.tts.aiter_websocket_audio",
        return_value=fake_audio(ws),
    ):
        # Consume the stream completely so the whole streaming path executes.
        audio_chunks = [
            chunk
            async for chunk in async_tts_client.stream_websocket(
                text_stream(), reference_id="voice_456"
            )
        ]

        # At least one message must have been sent; the first is the StartEvent.
        assert ws.send_bytes.called
        start_call = ws.send_bytes.call_args_list[0]
        payload = ormsgpack.unpackb(start_call.args[0])
        assert payload["request"]["reference_id"] == "voice_456"
535+
536+
@pytest.mark.asyncio
@patch("fishaudio.resources.tts.aconnect_ws")
async def test_stream_websocket_config_reference_id_overrides_parameter(
    self, mock_aconnect_ws, async_tts_client, async_mock_client_wrapper
):
    """Assert that ``config.reference_id`` wins over the argument (async).

    Both a ``TTSConfig`` carrying ``reference_id="voice_from_config"`` and a
    direct ``reference_id="voice_from_param"`` are passed to
    ``stream_websocket``; the decoded StartEvent is expected to carry the
    config value.
    """
    # Fake WebSocket that works as an async context manager yielding itself.
    ws = MagicMock()
    ws.__aenter__ = AsyncMock(return_value=ws)
    ws.__aexit__ = AsyncMock(return_value=None)
    ws.send_bytes = AsyncMock()
    mock_aconnect_ws.return_value = ws

    async def fake_audio(_ws):
        yield b"audio"

    async def text_stream():
        yield "Test"

    config = TTSConfig(reference_id="voice_from_config")

    with patch(
        "fishaudio.resources.tts.aiter_websocket_audio",
        return_value=fake_audio(ws),
    ):
        # Consume the stream completely; the chunks themselves are not checked.
        async for _ in async_tts_client.stream_websocket(
            text_stream(), reference_id="voice_from_param", config=config
        ):
            pass

        # Fail with a clear assertion (not an IndexError) if nothing was sent.
        assert ws.send_bytes.called
        start_call = ws.send_bytes.call_args_list[0]
        payload = ormsgpack.unpackb(start_call.args[0])
        assert payload["request"]["reference_id"] == "voice_from_config"
571+
572+
@pytest.mark.asyncio
@patch("fishaudio.resources.tts.aconnect_ws")
async def test_stream_websocket_with_references_parameter(
    self, mock_aconnect_ws, async_tts_client, async_mock_client_wrapper
):
    """Check that a ``references`` keyword argument reaches the StartEvent (async).

    Two ``ReferenceAudio`` objects are passed directly to
    ``stream_websocket``; the decoded StartEvent is expected to list both,
    in order, identified here by their ``text`` fields.
    """
    # Fake WebSocket that works as an async context manager yielding itself.
    ws = MagicMock()
    ws.__aenter__ = AsyncMock(return_value=ws)
    ws.__aexit__ = AsyncMock(return_value=None)
    ws.send_bytes = AsyncMock()
    mock_aconnect_ws.return_value = ws

    async def fake_audio(_ws):
        yield b"audio"

    async def text_stream():
        yield "Test"

    voice_samples = [
        ReferenceAudio(audio=b"ref_audio_1", text="Sample 1"),
        ReferenceAudio(audio=b"ref_audio_2", text="Sample 2"),
    ]

    with patch(
        "fishaudio.resources.tts.aiter_websocket_audio",
        return_value=fake_audio(ws),
    ):
        # Consume the stream completely; the chunks themselves are not checked.
        async for _ in async_tts_client.stream_websocket(
            text_stream(), references=voice_samples
        ):
            pass

        # Fail with a clear assertion (not an IndexError) if nothing was sent.
        assert ws.send_bytes.called
        start_call = ws.send_bytes.call_args_list[0]
        payload = ormsgpack.unpackb(start_call.args[0])
        sent_refs = payload["request"]["references"]
        assert len(sent_refs) == 2
        assert sent_refs[0]["text"] == "Sample 1"
        assert sent_refs[1]["text"] == "Sample 2"
612+
613+
@pytest.mark.asyncio
@patch("fishaudio.resources.tts.aconnect_ws")
async def test_stream_websocket_config_references_overrides_parameter(
    self, mock_aconnect_ws, async_tts_client, async_mock_client_wrapper
):
    """Assert that ``config.references`` wins over the argument (async).

    One ``ReferenceAudio`` is supplied through ``TTSConfig`` and a different
    one through the ``references`` keyword; the decoded StartEvent is
    expected to contain only the config entry.
    """
    # Fake WebSocket that works as an async context manager yielding itself.
    ws = MagicMock()
    ws.__aenter__ = AsyncMock(return_value=ws)
    ws.__aexit__ = AsyncMock(return_value=None)
    ws.send_bytes = AsyncMock()
    mock_aconnect_ws.return_value = ws

    async def fake_audio(_ws):
        yield b"audio"

    async def text_stream():
        yield "Test"

    config_refs = [ReferenceAudio(audio=b"config_audio", text="Config")]
    param_refs = [ReferenceAudio(audio=b"param_audio", text="Param")]
    config = TTSConfig(references=config_refs)

    with patch(
        "fishaudio.resources.tts.aiter_websocket_audio",
        return_value=fake_audio(ws),
    ):
        # Consume the stream completely; the chunks themselves are not checked.
        async for _ in async_tts_client.stream_websocket(
            text_stream(), references=param_refs, config=config
        ):
            pass

        # Fail with a clear assertion (not an IndexError) if nothing was sent.
        assert ws.send_bytes.called
        start_call = ws.send_bytes.call_args_list[0]
        payload = ormsgpack.unpackb(start_call.args[0])
        sent_refs = payload["request"]["references"]
        assert len(sent_refs) == 1
        assert sent_refs[0]["text"] == "Config"

0 commit comments

Comments
 (0)