Skip to content
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[project]
name = "cartesia"
version = "2.0.15"

[tool.poetry]
name = "cartesia"
Expand Down
3 changes: 3 additions & 0 deletions src/cartesia/tts/_async_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ async def send(
use_original_timestamps: bool = False,
continue_: bool = False,
max_buffer_delay_ms: Optional[int] = None,
pronunciation_dict_id: Optional[str] = None,
flush: bool = False,
) -> None:
"""Send audio generation requests to the WebSocket. The response can be received using the `receive` method.
Expand Down Expand Up @@ -116,6 +117,8 @@ async def send(
request_body["max_buffer_delay_ms"] = max_buffer_delay_ms
if flush:
request_body["flush"] = flush
if pronunciation_dict_id:
request_body["pronunciation_dict_id"] = pronunciation_dict_id

if generation_config is not None:
if isinstance(generation_config, dict):
Expand Down
3 changes: 3 additions & 0 deletions src/cartesia/tts/_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def send(
add_timestamps: bool = False,
add_phoneme_timestamps: bool = False,
use_original_timestamps: bool = False,
pronunciation_dict_id: Optional[str] = None
) -> Generator[bytes, None, None]:
"""Send audio generation requests to the WebSocket and yield responses.

Expand Down Expand Up @@ -111,6 +112,8 @@ def send(
request_body["use_original_timestamps"] = use_original_timestamps
if max_buffer_delay_ms:
request_body["max_buffer_delay_ms"] = max_buffer_delay_ms
if pronunciation_dict_id:
request_body["pronunciation_dict_id"] = pronunciation_dict_id

if generation_config is not None:
if isinstance(generation_config, dict):
Expand Down
1 change: 1 addition & 0 deletions src/cartesia/tts/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def bytes(
speed: typing.Optional[ModelSpeed] = OMIT,
pronunciation_dict_id: typing.Optional[str] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
pronunciation_dict_id: typing.Optional[str] = None
) -> typing.Iterator[bytes]:
"""
Parameters
Expand Down
57 changes: 57 additions & 0 deletions tests/custom/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,29 @@ def test_sse_err():
pass


def test_sse_pronunciation_dict(resources: _Resources):
    """Verify tts.sse accepts the pronunciation_dict_id keyword and still streams valid audio."""
    logger.info("Testing SSE with pronunciation_dict_id parameter")

    sse_stream = resources.client.tts.sse(
        transcript=SAMPLE_TRANSCRIPT,
        voice={"mode": "id", "id": SAMPLE_VOICE_ID},
        output_format=DEFAULT_OUTPUT_FORMAT_PARAMS,
        model_id=DEFAULT_MODEL_ID,
        # None exercises parameter acceptance without needing a real pronunciation dict ID.
        pronunciation_dict_id=None,
    )

    audio_parts = []
    for chunk in sse_stream:
        assert isinstance(chunk, WebSocketResponse_Chunk)
        audio_parts.append(base64.b64decode(chunk.data))

    _validate_audio_response(b"".join(audio_parts), DEFAULT_OUTPUT_FORMAT_PARAMS)


@pytest.mark.parametrize("output_format", TEST_RAW_OUTPUT_FORMATS)
@pytest.mark.parametrize("stream", [True, False])
def test_ws_sync(resources: _Resources, output_format: OutputFormatParams, stream: bool):
Expand Down Expand Up @@ -584,6 +607,40 @@ async def test_ws_timestamps(use_original_timestamps: bool):
await async_client.close()


@pytest.mark.asyncio
async def test_ws_pronunciation_dict():
    """Verify the async websocket send() accepts pronunciation_dict_id and yields valid audio."""
    logger.info("Testing WebSocket with pronunciation_dict_id parameter")

    async_client = create_async_client()
    ws = await async_client.tts.websocket()

    # None exercises parameter acceptance without needing a real pronunciation dict ID.
    response_stream = await ws.send(
        transcript=SAMPLE_TRANSCRIPT,
        voice={"mode": "id", "id": SAMPLE_VOICE_ID},
        output_format=DEFAULT_OUTPUT_FORMAT_PARAMS,
        model_id=DEFAULT_MODEL_ID,
        pronunciation_dict_id=None,
        stream=True,
    )

    audio_parts = []
    async for message in response_stream:
        _validate_schema(message)
        if message.audio is not None:
            audio_parts.append(message.audio)

    # Confirm the assembled stream decodes as audio in the requested format.
    _validate_audio_response(b"".join(audio_parts), DEFAULT_OUTPUT_FORMAT_PARAMS)

    # Release the connection and client resources before the test ends.
    await ws.close()
    await async_client.close()


def chunk_generator(transcripts):
for transcript in transcripts:
if transcript.endswith(" "):
Expand Down