Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 100 additions & 5 deletions src/fishaudio/resources/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import AsyncIterable, Iterable, Iterator, Optional, Union
from typing import AsyncIterable, Iterable, Iterator, List, Optional, Union

import ormsgpack
from httpx_ws import AsyncWebSocketSession, WebSocketSession, aconnect_ws, connect_ws
Expand All @@ -13,6 +13,7 @@
CloseEvent,
FlushEvent,
Model,
ReferenceAudio,
StartEvent,
TextEvent,
TTSConfig,
Expand Down Expand Up @@ -58,6 +59,8 @@ def convert(
self,
*,
text: str,
reference_id: Optional[str] = None,
references: List[ReferenceAudio] = [],
config: TTSConfig = TTSConfig(),
model: Model = "s1",
request_options: Optional[RequestOptions] = None,
Expand All @@ -67,6 +70,8 @@ def convert(

Args:
text: Text to synthesize
reference_id: Voice reference ID (overridden by config.reference_id if set)
references: Reference audio samples (overridden by config.references if set)
config: TTS configuration (audio settings, voice, model parameters)
model: TTS model to use
request_options: Request-level overrides
Expand All @@ -76,13 +81,22 @@ def convert(

Example:
```python
from fishaudio import FishAudio, TTSConfig
from fishaudio import FishAudio, TTSConfig, ReferenceAudio

client = FishAudio(api_key="...")

# Simple usage with defaults
audio = client.tts.convert(text="Hello world")

# With reference_id parameter
audio = client.tts.convert(text="Hello world", reference_id="your_model_id")

# With references parameter
audio = client.tts.convert(
text="Hello world",
references=[ReferenceAudio(audio=audio_bytes, text="sample")]
)

# Custom configuration
config = TTSConfig(format="wav", mp3_bitrate=192)
audio = client.tts.convert(text="Hello world", config=config)
Expand All @@ -94,6 +108,15 @@ def convert(
"""
# Build request payload from config
request = _config_to_tts_request(config, text)

# Use parameter reference_id only if config doesn't have one
if request.reference_id is None and reference_id is not None:
request.reference_id = reference_id

# Use parameter references only if config doesn't have any
if not request.references and references:
request.references = references

payload = request.model_dump(exclude_none=True)

# Make request with streaming
Expand All @@ -114,6 +137,8 @@ def stream_websocket(
self,
text_stream: Iterable[Union[str, TextEvent, FlushEvent]],
*,
reference_id: Optional[str] = None,
references: List[ReferenceAudio] = [],
config: TTSConfig = TTSConfig(),
model: Model = "s1",
max_workers: int = 10,
Expand All @@ -125,6 +150,8 @@ def stream_websocket(

Args:
text_stream: Iterator of text chunks to stream
reference_id: Voice reference ID (overridden by config.reference_id if set)
references: Reference audio samples (overridden by config.references if set)
config: TTS configuration (audio settings, voice, model parameters)
model: TTS model to use
max_workers: ThreadPoolExecutor workers for concurrent sender
Expand All @@ -134,7 +161,7 @@ def stream_websocket(

Example:
```python
from fishaudio import FishAudio, TTSConfig
from fishaudio import FishAudio, TTSConfig, ReferenceAudio

client = FishAudio(api_key="...")

Expand All @@ -148,6 +175,19 @@ def text_generator():
for audio_chunk in client.tts.stream_websocket(text_generator()):
f.write(audio_chunk)

# With reference_id parameter
with open("output.mp3", "wb") as f:
for audio_chunk in client.tts.stream_websocket(text_generator(), reference_id="your_model_id"):
f.write(audio_chunk)

# With references parameter
with open("output.mp3", "wb") as f:
for audio_chunk in client.tts.stream_websocket(
text_generator(),
references=[ReferenceAudio(audio=audio_bytes, text="sample")]
):
f.write(audio_chunk)

# Custom configuration
config = TTSConfig(format="wav", latency="normal")
with open("output.wav", "wb") as f:
Expand All @@ -158,6 +198,14 @@ def text_generator():
# Build TTSRequest from config
tts_request = _config_to_tts_request(config, text="")

# Use parameter reference_id only if config doesn't have one
if tts_request.reference_id is None and reference_id is not None:
tts_request.reference_id = reference_id

# Use parameter references only if config doesn't have any
if not tts_request.references and references:
tts_request.references = references

executor = ThreadPoolExecutor(max_workers=max_workers)

try:
Expand Down Expand Up @@ -202,6 +250,8 @@ async def convert(
self,
*,
text: str,
reference_id: Optional[str] = None,
references: List[ReferenceAudio] = [],
config: TTSConfig = TTSConfig(),
model: Model = "s1",
request_options: Optional[RequestOptions] = None,
Expand All @@ -211,6 +261,8 @@ async def convert(

Args:
text: Text to synthesize
reference_id: Voice reference ID (overridden by config.reference_id if set)
references: Reference audio samples (overridden by config.references if set)
config: TTS configuration (audio settings, voice, model parameters)
model: TTS model to use
request_options: Request-level overrides
Expand All @@ -220,13 +272,22 @@ async def convert(

Example:
```python
from fishaudio import AsyncFishAudio, TTSConfig
from fishaudio import AsyncFishAudio, TTSConfig, ReferenceAudio

client = AsyncFishAudio(api_key="...")

# Simple usage with defaults
audio = await client.tts.convert(text="Hello world")

# With reference_id parameter
audio = await client.tts.convert(text="Hello world", reference_id="your_model_id")

# With references parameter
audio = await client.tts.convert(
text="Hello world",
references=[ReferenceAudio(audio=audio_bytes, text="sample")]
)

# Custom configuration
config = TTSConfig(format="wav", mp3_bitrate=192)
audio = await client.tts.convert(text="Hello world", config=config)
Expand All @@ -238,6 +299,15 @@ async def convert(
"""
# Build request payload from config
request = _config_to_tts_request(config, text)

# Use parameter reference_id only if config doesn't have one
if request.reference_id is None and reference_id is not None:
request.reference_id = reference_id

# Use parameter references only if config doesn't have any
if not request.references and references:
request.references = references

payload = request.model_dump(exclude_none=True)

# Make request with streaming
Expand All @@ -258,6 +328,8 @@ async def stream_websocket(
self,
text_stream: AsyncIterable[Union[str, TextEvent, FlushEvent]],
*,
reference_id: Optional[str] = None,
references: List[ReferenceAudio] = [],
config: TTSConfig = TTSConfig(),
model: Model = "s1",
):
Expand All @@ -268,6 +340,8 @@ async def stream_websocket(

Args:
text_stream: Async iterator of text chunks to stream
reference_id: Voice reference ID (overridden by config.reference_id if set)
references: Reference audio samples (overridden by config.references if set)
config: TTS configuration (audio settings, voice, model parameters)
model: TTS model to use

Expand All @@ -276,7 +350,7 @@ async def stream_websocket(

Example:
```python
from fishaudio import AsyncFishAudio, TTSConfig
from fishaudio import AsyncFishAudio, TTSConfig, ReferenceAudio

client = AsyncFishAudio(api_key="...")

Expand All @@ -290,6 +364,19 @@ async def text_generator():
async for audio_chunk in client.tts.stream_websocket(text_generator()):
await f.write(audio_chunk)

# With reference_id parameter
async with aiofiles.open("output.mp3", "wb") as f:
async for audio_chunk in client.tts.stream_websocket(text_generator(), reference_id="your_model_id"):
await f.write(audio_chunk)

# With references parameter
async with aiofiles.open("output.mp3", "wb") as f:
async for audio_chunk in client.tts.stream_websocket(
text_generator(),
references=[ReferenceAudio(audio=audio_bytes, text="sample")]
):
await f.write(audio_chunk)

# Custom configuration
config = TTSConfig(format="wav", latency="normal")
async with aiofiles.open("output.wav", "wb") as f:
Expand All @@ -300,6 +387,14 @@ async def text_generator():
# Build TTSRequest from config
tts_request = _config_to_tts_request(config, text="")

# Use parameter reference_id only if config doesn't have one
if tts_request.reference_id is None and reference_id is not None:
tts_request.reference_id = reference_id

# Use parameter references only if config doesn't have any
if not tts_request.references and references:
tts_request.references = references

ws: AsyncWebSocketSession
async with aconnect_ws(
"/v1/tts/live",
Expand Down
Loading
Loading