Skip to content

Commit 9f3fac3

Browse files
noahlt and bpanahij authored
pronunciation dict updates (#57)
Takes the changes from #56 and adds pronunciation_dict_id support to the bytes method and the WebSocket send wrapper methods.
---------
Co-authored-by: Brian Johnson <brian@pjohnson.info>
1 parent 2855124 commit 9f3fac3

File tree

4 files changed: +72 −0 lines changed

4 files changed: +72 −0 lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[project]
22
name = "cartesia"
3+
version = "2.0.15"
34

45
[tool.poetry]
56
name = "cartesia"

src/cartesia/tts/_async_websocket.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ async def send(
7272
use_original_timestamps: bool = False,
7373
continue_: bool = False,
7474
max_buffer_delay_ms: Optional[int] = None,
75+
pronunciation_dict_id: Optional[str] = None,
7576
flush: bool = False,
7677
) -> None:
7778
"""Send audio generation requests to the WebSocket. The response can be received using the `receive` method.
@@ -116,6 +117,8 @@ async def send(
116117
request_body["max_buffer_delay_ms"] = max_buffer_delay_ms
117118
if flush:
118119
request_body["flush"] = flush
120+
if pronunciation_dict_id:
121+
request_body["pronunciation_dict_id"] = pronunciation_dict_id
119122

120123
if generation_config is not None:
121124
if isinstance(generation_config, dict):
@@ -383,6 +386,7 @@ async def send(
383386
add_timestamps: bool = False,
384387
add_phoneme_timestamps: bool = False,
385388
use_original_timestamps: bool = False,
389+
pronunciation_dict_id: Optional[str] = None,
386390
):
387391
"""See :meth:`_WebSocket.send` for details."""
388392
if context_id is None:
@@ -403,6 +407,7 @@ async def send(
403407
add_timestamps=add_timestamps,
404408
add_phoneme_timestamps=add_phoneme_timestamps,
405409
use_original_timestamps=use_original_timestamps,
410+
pronunciation_dict_id=pronunciation_dict_id,
406411
)
407412

408413
generator = ctx.receive()

src/cartesia/tts/_websocket.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def send(
7070
add_timestamps: bool = False,
7171
add_phoneme_timestamps: bool = False,
7272
use_original_timestamps: bool = False,
73+
pronunciation_dict_id: Optional[str] = None
7374
) -> Generator[bytes, None, None]:
7475
"""Send audio generation requests to the WebSocket and yield responses.
7576
@@ -111,6 +112,8 @@ def send(
111112
request_body["use_original_timestamps"] = use_original_timestamps
112113
if max_buffer_delay_ms:
113114
request_body["max_buffer_delay_ms"] = max_buffer_delay_ms
115+
if pronunciation_dict_id:
116+
request_body["pronunciation_dict_id"] = pronunciation_dict_id
114117

115118
if generation_config is not None:
116119
if isinstance(generation_config, dict):
@@ -370,6 +373,7 @@ def send(
370373
add_timestamps: bool = False,
371374
add_phoneme_timestamps: bool = False,
372375
use_original_timestamps: bool = False,
376+
pronunciation_dict_id: Optional[str] = None,
373377
):
374378
"""Send a request to the WebSocket to generate audio.
375379
@@ -402,6 +406,7 @@ def send(
402406
"add_timestamps": add_timestamps,
403407
"add_phoneme_timestamps": add_phoneme_timestamps,
404408
"use_original_timestamps": use_original_timestamps,
409+
"pronunciation_dict_id": pronunciation_dict_id,
405410
}
406411
generator = self._websocket_generator(request_body)
407412

tests/custom/test_client.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,29 @@ def test_sse_err():
431431
pass
432432

433433

434+
def test_sse_pronunciation_dict(resources: _Resources):
435+
logger.info("Testing SSE with pronunciation_dict_id parameter")
436+
client = resources.client
437+
transcript = SAMPLE_TRANSCRIPT
438+
439+
output_generate = client.tts.sse(
440+
transcript=transcript,
441+
voice={"mode": "id", "id": SAMPLE_VOICE_ID},
442+
output_format=DEFAULT_OUTPUT_FORMAT_PARAMS,
443+
model_id=DEFAULT_MODEL_ID,
444+
pronunciation_dict_id=None, # Test with None to verify parameter acceptance
445+
)
446+
447+
chunks = []
448+
for response in output_generate:
449+
assert isinstance(response, WebSocketResponse_Chunk)
450+
audio_bytes = base64.b64decode(response.data)
451+
chunks.append(audio_bytes)
452+
453+
data = b"".join(chunks)
454+
_validate_audio_response(data, DEFAULT_OUTPUT_FORMAT_PARAMS)
455+
456+
434457
@pytest.mark.parametrize("output_format", TEST_RAW_OUTPUT_FORMATS)
435458
@pytest.mark.parametrize("stream", [True, False])
436459
def test_ws_sync(resources: _Resources, output_format: OutputFormatParams, stream: bool):
@@ -584,6 +607,40 @@ async def test_ws_timestamps(use_original_timestamps: bool):
584607
await async_client.close()
585608

586609

610+
@pytest.mark.asyncio
611+
async def test_ws_pronunciation_dict():
612+
logger.info("Testing WebSocket with pronunciation_dict_id parameter")
613+
transcript = SAMPLE_TRANSCRIPT
614+
615+
async_client = create_async_client()
616+
ws = await async_client.tts.websocket()
617+
618+
# Test that pronunciation_dict_id parameter can be passed
619+
# Using None as we don't have a real pronunciation dict ID for testing
620+
output_generate = await ws.send(
621+
transcript=transcript,
622+
voice={"mode": "id", "id": SAMPLE_VOICE_ID},
623+
output_format=DEFAULT_OUTPUT_FORMAT_PARAMS,
624+
model_id=DEFAULT_MODEL_ID,
625+
pronunciation_dict_id=None, # Test with None to verify parameter acceptance
626+
stream=True,
627+
)
628+
629+
chunks = []
630+
async for out in output_generate:
631+
_validate_schema(out)
632+
if out.audio is not None:
633+
chunks.append(out.audio)
634+
635+
# Verify audio
636+
audio = b"".join(chunks)
637+
_validate_audio_response(audio, DEFAULT_OUTPUT_FORMAT_PARAMS)
638+
639+
# Close the websocket
640+
await ws.close()
641+
await async_client.close()
642+
643+
587644
def chunk_generator(transcripts):
588645
for transcript in transcripts:
589646
if transcript.endswith(" "):
@@ -1364,6 +1421,7 @@ def test_ws_phoneme_timestamps():
13641421
output_format=DEFAULT_OUTPUT_FORMAT_PARAMS,
13651422
model_id=DEFAULT_MODEL_ID,
13661423
add_phoneme_timestamps=True,
1424+
add_timestamps=True, # workaround, currently you need both add_timestamps and add_phoneme_timestamps to get phoneme timestamps
13671425
stream=True,
13681426
)
13691427
has_phoneme_timestamps = False
@@ -1407,6 +1465,7 @@ def test_continuation_phoneme_timestamps():
14071465
voice={"mode": "id", "id": SAMPLE_VOICE_ID},
14081466
output_format=DEFAULT_OUTPUT_FORMAT_PARAMS,
14091467
add_phoneme_timestamps=True,
1468+
add_timestamps=True, # workaround, currently you need both add_timestamps and add_phoneme_timestamps to get phoneme timestamps
14101469
)
14111470

14121471
has_phoneme_timestamps = False
@@ -1445,6 +1504,7 @@ async def test_ws_phoneme_timestamps_async():
14451504
output_format=DEFAULT_OUTPUT_FORMAT_PARAMS,
14461505
model_id=DEFAULT_MODEL_ID,
14471506
add_phoneme_timestamps=True,
1507+
add_timestamps=True, # workaround, currently you need both add_timestamps and add_phoneme_timestamps to get phoneme timestamps
14481508
stream=True,
14491509
)
14501510
has_phoneme_timestamps = False
@@ -1491,6 +1551,7 @@ async def test_continuation_phoneme_timestamps_async():
14911551
voice={"mode": "id", "id": SAMPLE_VOICE_ID},
14921552
output_format=DEFAULT_OUTPUT_FORMAT_PARAMS,
14931553
add_phoneme_timestamps=True,
1554+
add_timestamps=True, # workaround, currently you need both add_timestamps and add_phoneme_timestamps to get phoneme timestamps
14941555
continue_=True,
14951556
)
14961557

0 commit comments

Comments
 (0)