|
5 | 5 |
|
6 | 6 | from fishaudio.core import ClientWrapper, AsyncClientWrapper |
7 | 7 | from fishaudio.resources.tts import TTSClient, AsyncTTSClient |
8 | | -from fishaudio.types import Prosody, TTSConfig, TextEvent, FlushEvent |
| 8 | +from fishaudio.types import Prosody, TTSConfig, TextEvent, FlushEvent, ReferenceAudio |
| 9 | +import ormsgpack |
9 | 10 |
|
10 | 11 |
|
11 | 12 | @pytest.fixture |
@@ -181,6 +182,169 @@ def test_stream_websocket_max_workers( |
181 | 182 | # Verify ThreadPoolExecutor was created with max_workers=5 |
182 | 183 | mock_executor.assert_called_once_with(max_workers=5) |
183 | 184 |
|
| 185 | + @patch("fishaudio.resources.tts.connect_ws") |
| 186 | + @patch("fishaudio.resources.tts.ThreadPoolExecutor") |
| 187 | + def test_stream_websocket_with_reference_id_parameter( |
| 188 | + self, mock_executor, mock_connect_ws, tts_client, mock_client_wrapper |
| 189 | + ): |
| 190 | + """Test WebSocket streaming with reference_id as direct parameter.""" |
| 191 | + # Fake WebSocket returned by the patched connect_ws: usable as a context manager, send_bytes records each call |
| 192 | + mock_ws = MagicMock() |
| 193 | + mock_ws.__enter__ = Mock(return_value=mock_ws) |
| 194 | + mock_ws.__exit__ = Mock(return_value=None) |
| 195 | + mock_ws.send_bytes = Mock() |
| 196 | + mock_connect_ws.return_value = mock_ws |
| 197 | + |
| 198 | + # Replace executor.submit with a stub that calls the submitted function inline and returns a finished future |
| 199 | + def submit_side_effect(fn): |
| 200 | + fn() # Run the sender function right away, in the calling thread |
| 201 | + mock_future = Mock() |
| 202 | + mock_future.result.return_value = None |
| 203 | + return mock_future |
| 204 | + |
| 205 | + mock_executor_instance = Mock() |
| 206 | + mock_executor_instance.submit.side_effect = submit_side_effect |
| 207 | + mock_executor.return_value = mock_executor_instance |
| 208 | + |
| 209 | + with patch("fishaudio.resources.tts.iter_websocket_audio") as mock_receiver: |
| 210 | + mock_receiver.return_value = iter([b"audio"]) |
| 211 | + |
| 212 | + text_stream = iter(["Test"]) |
| 213 | + list(tts_client.stream_websocket(text_stream, reference_id="voice_456")) |
| 214 | + |
| 215 | + # At least one frame must have been sent; the first one is expected to be the StartEvent |
| 216 | + assert mock_ws.send_bytes.called |
| 217 | + # Decode the msgpack payload of the first send_bytes call and check the reference_id it carries |
| 218 | + first_call = mock_ws.send_bytes.call_args_list[0] |
| 219 | + start_event_payload = ormsgpack.unpackb(first_call[0][0]) |
| 220 | + assert start_event_payload["request"]["reference_id"] == "voice_456" |
| 221 | + |
| 222 | + @patch("fishaudio.resources.tts.connect_ws") |
| 223 | + @patch("fishaudio.resources.tts.ThreadPoolExecutor") |
| 224 | + def test_stream_websocket_config_reference_id_overrides_parameter( |
| 225 | + self, mock_executor, mock_connect_ws, tts_client, mock_client_wrapper |
| 226 | + ): |
| 227 | + """Test that config.reference_id overrides parameter reference_id.""" |
| 228 | + # Fake WebSocket returned by the patched connect_ws: usable as a context manager, send_bytes records each call |
| 229 | + mock_ws = MagicMock() |
| 230 | + mock_ws.__enter__ = Mock(return_value=mock_ws) |
| 231 | + mock_ws.__exit__ = Mock(return_value=None) |
| 232 | + mock_ws.send_bytes = Mock() |
| 233 | + mock_connect_ws.return_value = mock_ws |
| 234 | + |
| 235 | + # Replace executor.submit with a stub that calls the submitted function inline and returns a finished future |
| 236 | + def submit_side_effect(fn): |
| 237 | + fn() # Run the sender function right away, in the calling thread |
| 238 | + mock_future = Mock() |
| 239 | + mock_future.result.return_value = None |
| 240 | + return mock_future |
| 241 | + |
| 242 | + mock_executor_instance = Mock() |
| 243 | + mock_executor_instance.submit.side_effect = submit_side_effect |
| 244 | + mock_executor.return_value = mock_executor_instance |
| 245 | + |
| 246 | + with patch("fishaudio.resources.tts.iter_websocket_audio") as mock_receiver: |
| 247 | + mock_receiver.return_value = iter([b"audio"]) |
| 248 | + |
| 249 | + config = TTSConfig(reference_id="voice_from_config") |
| 250 | + text_stream = iter(["Test"]) |
| 251 | + list( |
| 252 | + tts_client.stream_websocket( |
| 253 | + text_stream, reference_id="voice_from_param", config=config |
| 254 | + ) |
| 255 | + ) |
| 256 | + |
| 257 | + # The decoded first frame must carry the config's reference_id, not the keyword argument's |
| 258 | + first_call = mock_ws.send_bytes.call_args_list[0] |
| 259 | + start_event_payload = ormsgpack.unpackb(first_call[0][0]) |
| 260 | + assert start_event_payload["request"]["reference_id"] == "voice_from_config" |
| 261 | + |
| 262 | + @patch("fishaudio.resources.tts.connect_ws") |
| 263 | + @patch("fishaudio.resources.tts.ThreadPoolExecutor") |
| 264 | + def test_stream_websocket_with_references_parameter( |
| 265 | + self, mock_executor, mock_connect_ws, tts_client, mock_client_wrapper |
| 266 | + ): |
| 267 | + """Test WebSocket streaming with references as direct parameter.""" |
| 268 | + # Fake WebSocket returned by the patched connect_ws: usable as a context manager, send_bytes records each call |
| 269 | + mock_ws = MagicMock() |
| 270 | + mock_ws.__enter__ = Mock(return_value=mock_ws) |
| 271 | + mock_ws.__exit__ = Mock(return_value=None) |
| 272 | + mock_ws.send_bytes = Mock() |
| 273 | + mock_connect_ws.return_value = mock_ws |
| 274 | + |
| 275 | + # Replace executor.submit with a stub that calls the submitted function inline and returns a finished future |
| 276 | + def submit_side_effect(fn): |
| 277 | + fn() # Run the sender function right away, in the calling thread |
| 278 | + mock_future = Mock() |
| 279 | + mock_future.result.return_value = None |
| 280 | + return mock_future |
| 281 | + |
| 282 | + mock_executor_instance = Mock() |
| 283 | + mock_executor_instance.submit.side_effect = submit_side_effect |
| 284 | + mock_executor.return_value = mock_executor_instance |
| 285 | + |
| 286 | + with patch("fishaudio.resources.tts.iter_websocket_audio") as mock_receiver: |
| 287 | + mock_receiver.return_value = iter([b"audio"]) |
| 288 | + |
| 289 | + references = [ |
| 290 | + ReferenceAudio(audio=b"ref_audio_1", text="Sample 1"), |
| 291 | + ReferenceAudio(audio=b"ref_audio_2", text="Sample 2"), |
| 292 | + ] |
| 293 | + |
| 294 | + text_stream = iter(["Test"]) |
| 295 | + list(tts_client.stream_websocket(text_stream, references=references)) |
| 296 | + |
| 297 | + # The decoded first frame must list both references, in the order they were passed |
| 298 | + first_call = mock_ws.send_bytes.call_args_list[0] |
| 299 | + start_event_payload = ormsgpack.unpackb(first_call[0][0]) |
| 300 | + assert len(start_event_payload["request"]["references"]) == 2 |
| 301 | + assert start_event_payload["request"]["references"][0]["text"] == "Sample 1" |
| 302 | + assert start_event_payload["request"]["references"][1]["text"] == "Sample 2" |
| 303 | + |
| 304 | + @patch("fishaudio.resources.tts.connect_ws") |
| 305 | + @patch("fishaudio.resources.tts.ThreadPoolExecutor") |
| 306 | + def test_stream_websocket_config_references_overrides_parameter( |
| 307 | + self, mock_executor, mock_connect_ws, tts_client, mock_client_wrapper |
| 308 | + ): |
| 309 | + """Test that config.references overrides parameter references.""" |
| 310 | + # Fake WebSocket returned by the patched connect_ws: usable as a context manager, send_bytes records each call |
| 311 | + mock_ws = MagicMock() |
| 312 | + mock_ws.__enter__ = Mock(return_value=mock_ws) |
| 313 | + mock_ws.__exit__ = Mock(return_value=None) |
| 314 | + mock_ws.send_bytes = Mock() |
| 315 | + mock_connect_ws.return_value = mock_ws |
| 316 | + |
| 317 | + # Replace executor.submit with a stub that calls the submitted function inline and returns a finished future |
| 318 | + def submit_side_effect(fn): |
| 319 | + fn() # Run the sender function right away, in the calling thread |
| 320 | + mock_future = Mock() |
| 321 | + mock_future.result.return_value = None |
| 322 | + return mock_future |
| 323 | + |
| 324 | + mock_executor_instance = Mock() |
| 325 | + mock_executor_instance.submit.side_effect = submit_side_effect |
| 326 | + mock_executor.return_value = mock_executor_instance |
| 327 | + |
| 328 | + with patch("fishaudio.resources.tts.iter_websocket_audio") as mock_receiver: |
| 329 | + mock_receiver.return_value = iter([b"audio"]) |
| 330 | + |
| 331 | + config_refs = [ReferenceAudio(audio=b"config_audio", text="Config")] |
| 332 | + param_refs = [ReferenceAudio(audio=b"param_audio", text="Param")] |
| 333 | + |
| 334 | + config = TTSConfig(references=config_refs) |
| 335 | + text_stream = iter(["Test"]) |
| 336 | + list( |
| 337 | + tts_client.stream_websocket( |
| 338 | + text_stream, references=param_refs, config=config |
| 339 | + ) |
| 340 | + ) |
| 341 | + |
| 342 | + # The decoded first frame must hold only the config's single reference; the keyword argument's must be absent |
| 343 | + first_call = mock_ws.send_bytes.call_args_list[0] |
| 344 | + start_event_payload = ormsgpack.unpackb(first_call[0][0]) |
| 345 | + assert len(start_event_payload["request"]["references"]) == 1 |
| 346 | + assert start_event_payload["request"]["references"][0]["text"] == "Config" |
| 347 | + |
184 | 348 |
|
185 | 349 | class TestAsyncTTSRealtimeClient: |
186 | 350 | """Test asynchronous AsyncTTSClient realtime streaming.""" |
@@ -331,3 +495,157 @@ async def text_stream(): |
331 | 495 |
|
332 | 496 | # Should have no audio |
333 | 497 | assert audio_chunks == [] |
| 498 | + |
| 499 | + @pytest.mark.asyncio |
| 500 | + @patch("fishaudio.resources.tts.aconnect_ws") |
| 501 | + async def test_stream_websocket_with_reference_id_parameter( |
| 502 | + self, mock_aconnect_ws, async_tts_client, async_mock_client_wrapper |
| 503 | + ): |
| 504 | + """Test async WebSocket streaming with reference_id as direct parameter.""" |
| 505 | + # Fake WebSocket returned by the patched aconnect_ws: usable as an async context manager, send_bytes is an AsyncMock that records each call |
| 506 | + mock_ws = MagicMock() |
| 507 | + mock_ws.__aenter__ = AsyncMock(return_value=mock_ws) |
| 508 | + mock_ws.__aexit__ = AsyncMock(return_value=None) |
| 509 | + mock_ws.send_bytes = AsyncMock() |
| 510 | + mock_aconnect_ws.return_value = mock_ws |
| 511 | + |
| 512 | + async def mock_audio_receiver(ws): |
| 513 | + yield b"audio" |
| 514 | + |
| 515 | + with patch( |
| 516 | + "fishaudio.resources.tts.aiter_websocket_audio", |
| 517 | + return_value=mock_audio_receiver(mock_ws), |
| 518 | + ): |
| 519 | + |
| 520 | + async def text_stream(): |
| 521 | + yield "Test" |
| 522 | + |
| 523 | + audio_chunks = [] |
| 524 | + async for chunk in async_tts_client.stream_websocket( |
| 525 | + text_stream(), reference_id="voice_456" |
| 526 | + ): |
| 527 | + audio_chunks.append(chunk) |
| 528 | + |
| 529 | + # At least one frame must have been sent; the first one is expected to be the StartEvent |
| 530 | + assert mock_ws.send_bytes.called |
| 531 | + # Decode the msgpack payload of the first send_bytes call and check the reference_id it carries |
| 532 | + first_call = mock_ws.send_bytes.call_args_list[0] |
| 533 | + start_event_payload = ormsgpack.unpackb(first_call[0][0]) |
| 534 | + assert start_event_payload["request"]["reference_id"] == "voice_456" |
| 535 | + |
| 536 | + @pytest.mark.asyncio |
| 537 | + @patch("fishaudio.resources.tts.aconnect_ws") |
| 538 | + async def test_stream_websocket_config_reference_id_overrides_parameter( |
| 539 | + self, mock_aconnect_ws, async_tts_client, async_mock_client_wrapper |
| 540 | + ): |
| 541 | + """Test that config.reference_id overrides parameter reference_id (async).""" |
| 542 | + # Fake WebSocket returned by the patched aconnect_ws: usable as an async context manager, send_bytes is an AsyncMock that records each call |
| 543 | + mock_ws = MagicMock() |
| 544 | + mock_ws.__aenter__ = AsyncMock(return_value=mock_ws) |
| 545 | + mock_ws.__aexit__ = AsyncMock(return_value=None) |
| 546 | + mock_ws.send_bytes = AsyncMock() |
| 547 | + mock_aconnect_ws.return_value = mock_ws |
| 548 | + |
| 549 | + async def mock_audio_receiver(ws): |
| 550 | + yield b"audio" |
| 551 | + |
| 552 | + with patch( |
| 553 | + "fishaudio.resources.tts.aiter_websocket_audio", |
| 554 | + return_value=mock_audio_receiver(mock_ws), |
| 555 | + ): |
| 556 | + config = TTSConfig(reference_id="voice_from_config") |
| 557 | + |
| 558 | + async def text_stream(): |
| 559 | + yield "Test" |
| 560 | + |
| 561 | + audio_chunks = [] |
| 562 | + async for chunk in async_tts_client.stream_websocket( |
| 563 | + text_stream(), reference_id="voice_from_param", config=config |
| 564 | + ): |
| 565 | + audio_chunks.append(chunk) |
| 566 | + |
| 567 | + # The decoded first frame must carry the config's reference_id, not the keyword argument's |
| 568 | + first_call = mock_ws.send_bytes.call_args_list[0] |
| 569 | + start_event_payload = ormsgpack.unpackb(first_call[0][0]) |
| 570 | + assert start_event_payload["request"]["reference_id"] == "voice_from_config" |
| 571 | + |
| 572 | + @pytest.mark.asyncio |
| 573 | + @patch("fishaudio.resources.tts.aconnect_ws") |
| 574 | + async def test_stream_websocket_with_references_parameter( |
| 575 | + self, mock_aconnect_ws, async_tts_client, async_mock_client_wrapper |
| 576 | + ): |
| 577 | + """Test async WebSocket streaming with references as direct parameter.""" |
| 578 | + # Fake WebSocket returned by the patched aconnect_ws: usable as an async context manager, send_bytes is an AsyncMock that records each call |
| 579 | + mock_ws = MagicMock() |
| 580 | + mock_ws.__aenter__ = AsyncMock(return_value=mock_ws) |
| 581 | + mock_ws.__aexit__ = AsyncMock(return_value=None) |
| 582 | + mock_ws.send_bytes = AsyncMock() |
| 583 | + mock_aconnect_ws.return_value = mock_ws |
| 584 | + |
| 585 | + async def mock_audio_receiver(ws): |
| 586 | + yield b"audio" |
| 587 | + |
| 588 | + with patch( |
| 589 | + "fishaudio.resources.tts.aiter_websocket_audio", |
| 590 | + return_value=mock_audio_receiver(mock_ws), |
| 591 | + ): |
| 592 | + references = [ |
| 593 | + ReferenceAudio(audio=b"ref_audio_1", text="Sample 1"), |
| 594 | + ReferenceAudio(audio=b"ref_audio_2", text="Sample 2"), |
| 595 | + ] |
| 596 | + |
| 597 | + async def text_stream(): |
| 598 | + yield "Test" |
| 599 | + |
| 600 | + audio_chunks = [] |
| 601 | + async for chunk in async_tts_client.stream_websocket( |
| 602 | + text_stream(), references=references |
| 603 | + ): |
| 604 | + audio_chunks.append(chunk) |
| 605 | + |
| 606 | + # The decoded first frame must list both references, in the order they were passed |
| 607 | + first_call = mock_ws.send_bytes.call_args_list[0] |
| 608 | + start_event_payload = ormsgpack.unpackb(first_call[0][0]) |
| 609 | + assert len(start_event_payload["request"]["references"]) == 2 |
| 610 | + assert start_event_payload["request"]["references"][0]["text"] == "Sample 1" |
| 611 | + assert start_event_payload["request"]["references"][1]["text"] == "Sample 2" |
| 612 | + |
| 613 | + @pytest.mark.asyncio |
| 614 | + @patch("fishaudio.resources.tts.aconnect_ws") |
| 615 | + async def test_stream_websocket_config_references_overrides_parameter( |
| 616 | + self, mock_aconnect_ws, async_tts_client, async_mock_client_wrapper |
| 617 | + ): |
| 618 | + """Test that config.references overrides parameter references (async).""" |
| 619 | + # Fake WebSocket returned by the patched aconnect_ws: usable as an async context manager, send_bytes is an AsyncMock that records each call |
| 620 | + mock_ws = MagicMock() |
| 621 | + mock_ws.__aenter__ = AsyncMock(return_value=mock_ws) |
| 622 | + mock_ws.__aexit__ = AsyncMock(return_value=None) |
| 623 | + mock_ws.send_bytes = AsyncMock() |
| 624 | + mock_aconnect_ws.return_value = mock_ws |
| 625 | + |
| 626 | + async def mock_audio_receiver(ws): |
| 627 | + yield b"audio" |
| 628 | + |
| 629 | + with patch( |
| 630 | + "fishaudio.resources.tts.aiter_websocket_audio", |
| 631 | + return_value=mock_audio_receiver(mock_ws), |
| 632 | + ): |
| 633 | + config_refs = [ReferenceAudio(audio=b"config_audio", text="Config")] |
| 634 | + param_refs = [ReferenceAudio(audio=b"param_audio", text="Param")] |
| 635 | + |
| 636 | + config = TTSConfig(references=config_refs) |
| 637 | + |
| 638 | + async def text_stream(): |
| 639 | + yield "Test" |
| 640 | + |
| 641 | + audio_chunks = [] |
| 642 | + async for chunk in async_tts_client.stream_websocket( |
| 643 | + text_stream(), references=param_refs, config=config |
| 644 | + ): |
| 645 | + audio_chunks.append(chunk) |
| 646 | + |
| 647 | + # The decoded first frame must hold only the config's single reference; the keyword argument's must be absent |
| 648 | + first_call = mock_ws.send_bytes.call_args_list[0] |
| 649 | + start_event_payload = ormsgpack.unpackb(first_call[0][0]) |
| 650 | + assert len(start_event_payload["request"]["references"]) == 1 |
| 651 | + assert start_event_payload["request"]["references"][0]["text"] == "Config" |
0 commit comments