Skip to content

Commit bb3b988

Browse files
committed
feat: add support for additional parameters in TTS methods (format, latency, speed)
Signed-off-by: James Ding <[email protected]>
1 parent 87a4ada commit bb3b988

File tree

3 files changed

+419
-68
lines changed

3 files changed

+419
-68
lines changed

src/fishaudio/resources/tts.py

Lines changed: 151 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@
1010
from .realtime import aiter_websocket_audio, iter_websocket_audio
1111
from ..core import AsyncClientWrapper, ClientWrapper, RequestOptions
1212
from ..types import (
13+
AudioFormat,
1314
CloseEvent,
1415
FlushEvent,
16+
LatencyMode,
1517
Model,
18+
Prosody,
1619
ReferenceAudio,
1720
StartEvent,
1821
TextEvent,
@@ -61,6 +64,9 @@ def convert(
6164
text: str,
6265
reference_id: Optional[str] = None,
6366
references: Optional[List[ReferenceAudio]] = None,
67+
format: Optional[AudioFormat] = None,
68+
latency: Optional[LatencyMode] = None,
69+
speed: Optional[float] = None,
6470
config: TTSConfig = TTSConfig(),
6571
model: Model = "s1",
6672
request_options: Optional[RequestOptions] = None,
@@ -70,8 +76,11 @@ def convert(
7076
7177
Args:
7278
text: Text to synthesize
73-
reference_id: Voice reference ID (overridden by config.reference_id if set)
74-
references: Reference audio samples (overridden by config.references if set)
79+
reference_id: Voice reference ID (overrides config.reference_id if provided)
80+
references: Reference audio samples (overrides config.references if provided)
81+
format: Audio format - "mp3", "wav", or "pcm" (overrides config.format if provided)
82+
latency: Latency mode - "normal" or "balanced" (overrides config.latency if provided)
83+
speed: Speech speed multiplier, e.g. 1.5 for 1.5x speed (when provided, replaces config.prosody with Prosody(speed=speed); other prosody settings from config are not carried over)
7584
config: TTS configuration (audio settings, voice, model parameters)
7685
model: TTS model to use
7786
request_options: Request-level overrides
@@ -88,6 +97,12 @@ def convert(
8897
# Simple usage with defaults
8998
audio = client.tts.convert(text="Hello world")
9099
100+
# With format parameter
101+
audio = client.tts.convert(text="Hello world", format="wav")
102+
103+
# With speed parameter
104+
audio = client.tts.convert(text="Hello world", speed=1.5)
105+
91106
# With reference_id parameter
92107
audio = client.tts.convert(text="Hello world", reference_id="your_model_id")
93108
@@ -97,9 +112,18 @@ def convert(
97112
references=[ReferenceAudio(audio=audio_bytes, text="sample")]
98113
)
99114
100-
# Custom configuration
101-
config = TTSConfig(format="wav", mp3_bitrate=192)
102-
audio = client.tts.convert(text="Hello world", config=config)
115+
# Combine multiple parameters
116+
audio = client.tts.convert(
117+
text="Hello world",
118+
format="wav",
119+
speed=1.2,
120+
latency="normal"
121+
)
122+
123+
# Parameters override config values
124+
config = TTSConfig(format="mp3", prosody=Prosody(speed=1.0))
125+
audio = client.tts.convert(text="Hello world", format="wav", config=config)
126+
# Result: format="wav" (parameter wins)
103127
104128
with open("output.mp3", "wb") as f:
105129
for chunk in audio:
@@ -109,14 +133,22 @@ def convert(
109133
# Build request payload from config
110134
request = _config_to_tts_request(config, text)
111135

112-
# Use parameter reference_id only if config doesn't have one
113-
if request.reference_id is None and reference_id is not None:
136+
# Apply direct parameters (always override config when provided)
137+
if reference_id is not None:
114138
request.reference_id = reference_id
115139

116-
# Use parameter references only if config doesn't have any
117-
if not request.references and references:
140+
if references is not None:
118141
request.references = references
119142

143+
if format is not None:
144+
request.format = format
145+
146+
if latency is not None:
147+
request.latency = latency
148+
149+
if speed is not None:
150+
request.prosody = Prosody(speed=speed)
151+
120152
payload = request.model_dump(exclude_none=True)
121153

122154
# Make request with streaming
@@ -139,6 +171,9 @@ def stream_websocket(
139171
*,
140172
reference_id: Optional[str] = None,
141173
references: Optional[List[ReferenceAudio]] = None,
174+
format: Optional[AudioFormat] = None,
175+
latency: Optional[LatencyMode] = None,
176+
speed: Optional[float] = None,
142177
config: TTSConfig = TTSConfig(),
143178
model: Model = "s1",
144179
max_workers: int = 10,
@@ -150,8 +185,11 @@ def stream_websocket(
150185
151186
Args:
152187
text_stream: Iterator of text chunks to stream
153-
reference_id: Voice reference ID (overridden by config.reference_id if set)
154-
references: Reference audio samples (overridden by config.references if set)
188+
reference_id: Voice reference ID (overrides config.reference_id if provided)
189+
references: Reference audio samples (overrides config.references if provided)
190+
format: Audio format - "mp3", "wav", or "pcm" (overrides config.format if provided)
191+
latency: Latency mode - "normal" or "balanced" (overrides config.latency if provided)
192+
speed: Speech speed multiplier, e.g. 1.5 for 1.5x speed (when provided, replaces config.prosody with Prosody(speed=speed); other prosody settings from config are not carried over)
155193
config: TTS configuration (audio settings, voice, model parameters)
156194
model: TTS model to use
157195
max_workers: ThreadPoolExecutor workers for concurrent sender
@@ -175,6 +213,15 @@ def text_generator():
175213
for audio_chunk in client.tts.stream_websocket(text_generator()):
176214
f.write(audio_chunk)
177215
216+
# With format and speed parameters
217+
with open("output.wav", "wb") as f:
218+
for audio_chunk in client.tts.stream_websocket(
219+
text_generator(),
220+
format="wav",
221+
speed=1.3
222+
):
223+
f.write(audio_chunk)
224+
178225
# With reference_id parameter
179226
with open("output.mp3", "wb") as f:
180227
for audio_chunk in client.tts.stream_websocket(text_generator(), reference_id="your_model_id"):
@@ -188,24 +235,36 @@ def text_generator():
188235
):
189236
f.write(audio_chunk)
190237
191-
# Custom configuration
192-
config = TTSConfig(format="wav", latency="normal")
238+
# Parameters override config values
239+
config = TTSConfig(format="mp3", latency="balanced")
193240
with open("output.wav", "wb") as f:
194-
for audio_chunk in client.tts.stream_websocket(text_generator(), config=config):
241+
for audio_chunk in client.tts.stream_websocket(
242+
text_generator(),
243+
format="wav", # Parameter wins
244+
config=config
245+
):
195246
f.write(audio_chunk)
196247
```
197248
"""
198249
# Build TTSRequest from config
199250
tts_request = _config_to_tts_request(config, text="")
200251

201-
# Use parameter reference_id only if config doesn't have one
202-
if tts_request.reference_id is None and reference_id is not None:
252+
# Apply direct parameters (always override config when provided)
253+
if reference_id is not None:
203254
tts_request.reference_id = reference_id
204255

205-
# Use parameter references only if config doesn't have any
206-
if not tts_request.references and references:
256+
if references is not None:
207257
tts_request.references = references
208258

259+
if format is not None:
260+
tts_request.format = format
261+
262+
if latency is not None:
263+
tts_request.latency = latency
264+
265+
if speed is not None:
266+
tts_request.prosody = Prosody(speed=speed)
267+
209268
executor = ThreadPoolExecutor(max_workers=max_workers)
210269

211270
try:
@@ -252,6 +311,9 @@ async def convert(
252311
text: str,
253312
reference_id: Optional[str] = None,
254313
references: Optional[List[ReferenceAudio]] = None,
314+
format: Optional[AudioFormat] = None,
315+
latency: Optional[LatencyMode] = None,
316+
speed: Optional[float] = None,
255317
config: TTSConfig = TTSConfig(),
256318
model: Model = "s1",
257319
request_options: Optional[RequestOptions] = None,
@@ -261,8 +323,11 @@ async def convert(
261323
262324
Args:
263325
text: Text to synthesize
264-
reference_id: Voice reference ID (overridden by config.reference_id if set)
265-
references: Reference audio samples (overridden by config.references if set)
326+
reference_id: Voice reference ID (overrides config.reference_id if provided)
327+
references: Reference audio samples (overrides config.references if provided)
328+
format: Audio format - "mp3", "wav", or "pcm" (overrides config.format if provided)
329+
latency: Latency mode - "normal" or "balanced" (overrides config.latency if provided)
330+
speed: Speech speed multiplier, e.g. 1.5 for 1.5x speed (when provided, replaces config.prosody with Prosody(speed=speed); other prosody settings from config are not carried over)
266331
config: TTS configuration (audio settings, voice, model parameters)
267332
model: TTS model to use
268333
request_options: Request-level overrides
@@ -279,6 +344,12 @@ async def convert(
279344
# Simple usage with defaults
280345
audio = await client.tts.convert(text="Hello world")
281346
347+
# With format parameter
348+
audio = await client.tts.convert(text="Hello world", format="wav")
349+
350+
# With speed parameter
351+
audio = await client.tts.convert(text="Hello world", speed=1.5)
352+
282353
# With reference_id parameter
283354
audio = await client.tts.convert(text="Hello world", reference_id="your_model_id")
284355
@@ -288,9 +359,18 @@ async def convert(
288359
references=[ReferenceAudio(audio=audio_bytes, text="sample")]
289360
)
290361
291-
# Custom configuration
292-
config = TTSConfig(format="wav", mp3_bitrate=192)
293-
audio = await client.tts.convert(text="Hello world", config=config)
362+
# Combine multiple parameters
363+
audio = await client.tts.convert(
364+
text="Hello world",
365+
format="wav",
366+
speed=1.2,
367+
latency="normal"
368+
)
369+
370+
# Parameters override config values
371+
config = TTSConfig(format="mp3", prosody=Prosody(speed=1.0))
372+
audio = await client.tts.convert(text="Hello world", format="wav", config=config)
373+
# Result: format="wav" (parameter wins)
294374
295375
async with aiofiles.open("output.mp3", "wb") as f:
296376
async for chunk in audio:
@@ -300,14 +380,22 @@ async def convert(
300380
# Build request payload from config
301381
request = _config_to_tts_request(config, text)
302382

303-
# Use parameter reference_id only if config doesn't have one
304-
if request.reference_id is None and reference_id is not None:
383+
# Apply direct parameters (always override config when provided)
384+
if reference_id is not None:
305385
request.reference_id = reference_id
306386

307-
# Use parameter references only if config doesn't have any
308-
if not request.references and references:
387+
if references is not None:
309388
request.references = references
310389

390+
if format is not None:
391+
request.format = format
392+
393+
if latency is not None:
394+
request.latency = latency
395+
396+
if speed is not None:
397+
request.prosody = Prosody(speed=speed)
398+
311399
payload = request.model_dump(exclude_none=True)
312400

313401
# Make request with streaming
@@ -330,6 +418,9 @@ async def stream_websocket(
330418
*,
331419
reference_id: Optional[str] = None,
332420
references: Optional[List[ReferenceAudio]] = None,
421+
format: Optional[AudioFormat] = None,
422+
latency: Optional[LatencyMode] = None,
423+
speed: Optional[float] = None,
333424
config: TTSConfig = TTSConfig(),
334425
model: Model = "s1",
335426
):
@@ -340,8 +431,11 @@ async def stream_websocket(
340431
341432
Args:
342433
text_stream: Async iterator of text chunks to stream
343-
reference_id: Voice reference ID (overridden by config.reference_id if set)
344-
references: Reference audio samples (overridden by config.references if set)
434+
reference_id: Voice reference ID (overrides config.reference_id if provided)
435+
references: Reference audio samples (overrides config.references if provided)
436+
format: Audio format - "mp3", "wav", or "pcm" (overrides config.format if provided)
437+
latency: Latency mode - "normal" or "balanced" (overrides config.latency if provided)
438+
speed: Speech speed multiplier, e.g. 1.5 for 1.5x speed (when provided, replaces config.prosody with Prosody(speed=speed); other prosody settings from config are not carried over)
345439
config: TTS configuration (audio settings, voice, model parameters)
346440
model: TTS model to use
347441
@@ -364,6 +458,15 @@ async def text_generator():
364458
async for audio_chunk in client.tts.stream_websocket(text_generator()):
365459
await f.write(audio_chunk)
366460
461+
# With format and speed parameters
462+
async with aiofiles.open("output.wav", "wb") as f:
463+
async for audio_chunk in client.tts.stream_websocket(
464+
text_generator(),
465+
format="wav",
466+
speed=1.3
467+
):
468+
await f.write(audio_chunk)
469+
367470
# With reference_id parameter
368471
async with aiofiles.open("output.mp3", "wb") as f:
369472
async for audio_chunk in client.tts.stream_websocket(text_generator(), reference_id="your_model_id"):
@@ -377,24 +480,36 @@ async def text_generator():
377480
):
378481
await f.write(audio_chunk)
379482
380-
# Custom configuration
381-
config = TTSConfig(format="wav", latency="normal")
483+
# Parameters override config values
484+
config = TTSConfig(format="mp3", latency="balanced")
382485
async with aiofiles.open("output.wav", "wb") as f:
383-
async for audio_chunk in client.tts.stream_websocket(text_generator(), config=config):
486+
async for audio_chunk in client.tts.stream_websocket(
487+
text_generator(),
488+
format="wav", # Parameter wins
489+
config=config
490+
):
384491
await f.write(audio_chunk)
385492
```
386493
"""
387494
# Build TTSRequest from config
388495
tts_request = _config_to_tts_request(config, text="")
389496

390-
# Use parameter reference_id only if config doesn't have one
391-
if tts_request.reference_id is None and reference_id is not None:
497+
# Apply direct parameters (always override config when provided)
498+
if reference_id is not None:
392499
tts_request.reference_id = reference_id
393500

394-
# Use parameter references only if config doesn't have any
395-
if not tts_request.references and references:
501+
if references is not None:
396502
tts_request.references = references
397503

504+
if format is not None:
505+
tts_request.format = format
506+
507+
if latency is not None:
508+
tts_request.latency = latency
509+
510+
if speed is not None:
511+
tts_request.prosody = Prosody(speed=speed)
512+
398513
ws: AsyncWebSocketSession
399514
async with aconnect_ws(
400515
"/v1/tts/live",

0 commit comments

Comments
 (0)