diff --git a/pyproject.toml b/pyproject.toml index 1cb3805..30bbb7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,6 @@ [project] name = "cartesia" +version = "2.0.15" [tool.poetry] name = "cartesia" diff --git a/src/cartesia/tts/_async_websocket.py b/src/cartesia/tts/_async_websocket.py index 64ed0fd..9ff0eb6 100644 --- a/src/cartesia/tts/_async_websocket.py +++ b/src/cartesia/tts/_async_websocket.py @@ -72,6 +72,7 @@ async def send( use_original_timestamps: bool = False, continue_: bool = False, max_buffer_delay_ms: Optional[int] = None, + pronunciation_dict_id: Optional[str] = None, flush: bool = False, ) -> None: """Send audio generation requests to the WebSocket. The response can be received using the `receive` method. @@ -116,6 +117,8 @@ async def send( request_body["max_buffer_delay_ms"] = max_buffer_delay_ms if flush: request_body["flush"] = flush + if pronunciation_dict_id: + request_body["pronunciation_dict_id"] = pronunciation_dict_id if generation_config is not None: if isinstance(generation_config, dict): diff --git a/src/cartesia/tts/_websocket.py b/src/cartesia/tts/_websocket.py index 5955664..0d136da 100644 --- a/src/cartesia/tts/_websocket.py +++ b/src/cartesia/tts/_websocket.py @@ -70,6 +70,7 @@ def send( add_timestamps: bool = False, add_phoneme_timestamps: bool = False, use_original_timestamps: bool = False, + pronunciation_dict_id: Optional[str] = None ) -> Generator[bytes, None, None]: """Send audio generation requests to the WebSocket and yield responses. 
@@ -111,6 +112,8 @@ def send( request_body["use_original_timestamps"] = use_original_timestamps if max_buffer_delay_ms: request_body["max_buffer_delay_ms"] = max_buffer_delay_ms + if pronunciation_dict_id: + request_body["pronunciation_dict_id"] = pronunciation_dict_id if generation_config is not None: if isinstance(generation_config, dict): diff --git a/tests/custom/test_client.py b/tests/custom/test_client.py index c37c343..9108a86 100644 --- a/tests/custom/test_client.py +++ b/tests/custom/test_client.py @@ -431,6 +431,29 @@ def test_sse_err(): pass +def test_sse_pronunciation_dict(resources: _Resources): + logger.info("Testing SSE with pronunciation_dict_id parameter") + client = resources.client + transcript = SAMPLE_TRANSCRIPT + + output_generate = client.tts.sse( + transcript=transcript, + voice={"mode": "id", "id": SAMPLE_VOICE_ID}, + output_format=DEFAULT_OUTPUT_FORMAT_PARAMS, + model_id=DEFAULT_MODEL_ID, + pronunciation_dict_id=None, # Test with None to verify parameter acceptance + ) + + chunks = [] + for response in output_generate: + assert isinstance(response, WebSocketResponse_Chunk) + audio_bytes = base64.b64decode(response.data) + chunks.append(audio_bytes) + + data = b"".join(chunks) + _validate_audio_response(data, DEFAULT_OUTPUT_FORMAT_PARAMS) + + @pytest.mark.parametrize("output_format", TEST_RAW_OUTPUT_FORMATS) @pytest.mark.parametrize("stream", [True, False]) def test_ws_sync(resources: _Resources, output_format: OutputFormatParams, stream: bool): @@ -584,6 +607,40 @@ async def 
test_ws_timestamps(use_original_timestamps: bool): await async_client.close() +@pytest.mark.asyncio +async def test_ws_pronunciation_dict(): + logger.info("Testing WebSocket with pronunciation_dict_id parameter") + transcript = SAMPLE_TRANSCRIPT + + async_client = create_async_client() + ws = await async_client.tts.websocket() + + # Test that pronunciation_dict_id parameter can be passed + # Using None as we don't have a real pronunciation dict ID for testing + output_generate = await ws.send( + transcript=transcript, + voice={"mode": "id", "id": SAMPLE_VOICE_ID}, + output_format=DEFAULT_OUTPUT_FORMAT_PARAMS, + model_id=DEFAULT_MODEL_ID, + pronunciation_dict_id=None, # Test with None to verify parameter acceptance + stream=True, + ) + + chunks = [] + async for out in output_generate: + _validate_schema(out) + if out.audio is not None: + chunks.append(out.audio) + + # Verify audio + audio = b"".join(chunks) + _validate_audio_response(audio, DEFAULT_OUTPUT_FORMAT_PARAMS) + + # Close the websocket + await ws.close() + await async_client.close() + + def chunk_generator(transcripts): for transcript in transcripts: if transcript.endswith(" "):