Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cartesia/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class OutputFormatMapping:
"raw_pcm_s16le_8000": {"container": "raw", "encoding": "pcm_s16le", "sample_rate": 8000},
"raw_pcm_mulaw_8000": {"container": "raw", "encoding": "pcm_mulaw", "sample_rate": 8000},
"raw_pcm_alaw_8000": {"container": "raw", "encoding": "pcm_alaw", "sample_rate": 8000},
"mp3_128kbps_44100": {"container": "mp3", "bit_rate": 128000, "sample_rate": 44100},
}

@classmethod
Expand Down
7 changes: 7 additions & 0 deletions cartesia/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ def get_output_format(output_format_name: str) -> OutputFormat:
output_format_obj = OutputFormatMapping.get_format(output_format_name)
else:
raise ValueError(f"Unsupported format: {output_format_name}")

if output_format_obj["container"].lower() == "mp3":
return OutputFormat(
container=output_format_obj["container"],
bit_rate=output_format_obj["bit_rate"],
sample_rate=output_format_obj["sample_rate"],
)

return OutputFormat(
container=output_format_obj["container"],
Expand Down
7 changes: 6 additions & 1 deletion tests/test_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,7 @@ async def test_continuation_websocket_context_three_contexts_parallel():
"raw_pcm_s16le_8000",
"raw_pcm_mulaw_8000",
"raw_pcm_alaw_8000",
"mp3_128kbps_44100"
]

deprecated_output_format_names = [
Expand All @@ -1021,8 +1022,12 @@ def test_output_formats(resources: _Resources, output_format_name: str):
output_format = resources.client.tts.get_output_format(output_format_name)
assert isinstance(output_format, dict), "Output is not of type dict"
assert output_format["container"] is not None, "Output format container is None"
assert output_format["encoding"] is not None, "Output format encoding is None"
assert output_format["sample_rate"] is not None, "Output format sample rate is None"
if output_format["container"]=="mp3":
assert output_format["bit_rate"] is not None, "Output format bit rate is None"
else:
assert output_format["encoding"] is not None, "Output format encoding is None"


def test_invalid_output_format(resources: _Resources):
logger.info("Testing invalid output format")
Expand Down