Skip to content

Commit e0b26c1

Browse files
committed
feat: add advanced generation parameters to TTSConfig and update related tests
1 parent 4ab33a4 commit e0b26c1

File tree

5 files changed

+157
-2
lines changed

5 files changed

+157
-2
lines changed

src/fishaudio/resources/tts.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,11 @@ def _config_to_tts_request(config: TTSConfig, text: str) -> TTSRequest:
4141
prosody=config.prosody,
4242
top_p=config.top_p,
4343
temperature=config.temperature,
44+
max_new_tokens=config.max_new_tokens,
45+
repetition_penalty=config.repetition_penalty,
46+
min_chunk_length=config.min_chunk_length,
47+
condition_on_previous_chunks=config.condition_on_previous_chunks,
48+
early_stop_threshold=config.early_stop_threshold,
4449
)
4550

4651

src/fishaudio/types/shared.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,13 +27,13 @@ class PaginatedResponse(BaseModel, Generic[T]):
2727
AudioFormat = Literal["wav", "pcm", "mp3", "opus"]
2828

2929
# Visibility types
30-
Visibility = Literal["public", "unlist", "private"]
30+
Visibility = Literal["public", "unlisted", "private"]
3131

3232
# Training mode types
3333
TrainMode = Literal["fast"]
3434

3535
# Model state types
36-
ModelState = Literal["created", "training", "trained", "failed"]
36+
ModelState = Literal["created", "training", "trained", "failed", "ready"]
3737

3838
# Latency modes
3939
LatencyMode = Literal["normal", "balanced"]

src/fishaudio/types/tts.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,11 @@ class TTSConfig(BaseModel):
7575
top_p: Nucleus sampling parameter for token selection. Range: 0.0-1.0. Default: 0.7
7676
temperature: Randomness in generation. Range: 0.0-1.0. Default: 0.7.
7777
Higher = more varied, lower = more consistent
78+
max_new_tokens: Maximum number of tokens to generate. Default: 1024
79+
repetition_penalty: Penalty for repeated tokens. Default: 1.2
80+
min_chunk_length: Minimum chunk length for generation. Default: 50
81+
condition_on_previous_chunks: Whether to condition generation on previous chunks. Default: True
82+
early_stop_threshold: Threshold for early stopping. Default: 1.0
7883
"""
7984

8085
# Audio output settings
@@ -97,6 +102,13 @@ class TTSConfig(BaseModel):
97102
top_p: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7
98103
temperature: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7
99104

105+
# Advanced generation parameters
106+
max_new_tokens: int = 1024
107+
repetition_penalty: float = 1.2
108+
min_chunk_length: int = 50
109+
condition_on_previous_chunks: bool = True
110+
early_stop_threshold: float = 1.0
111+
100112

101113
class TTSRequest(BaseModel):
102114
"""
@@ -119,6 +131,11 @@ class TTSRequest(BaseModel):
119131
prosody: Speech speed and volume settings. Default: None
120132
top_p: Nucleus sampling for token selection. Range: 0.0-1.0. Default: 0.7
121133
temperature: Randomness in generation. Range: 0.0-1.0. Default: 0.7
134+
max_new_tokens: Maximum number of tokens to generate. Default: 1024
135+
repetition_penalty: Penalty for repeated tokens. Default: 1.2
136+
min_chunk_length: Minimum chunk length for generation. Default: 50
137+
condition_on_previous_chunks: Whether to condition generation on previous chunks. Default: True
138+
early_stop_threshold: Threshold for early stopping. Default: 1.0
122139
"""
123140

124141
text: str
@@ -134,6 +151,11 @@ class TTSRequest(BaseModel):
134151
prosody: Optional[Prosody] = None
135152
top_p: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7
136153
temperature: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7
154+
max_new_tokens: int = 1024
155+
repetition_penalty: float = 1.2
156+
min_chunk_length: int = 50
157+
condition_on_previous_chunks: bool = True
158+
early_stop_threshold: float = 1.0
137159

138160

139161
# WebSocket event types for streaming TTS

tests/unit/test_tts.py

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -418,6 +418,51 @@ def test_convert_combined_convenience_parameters(
418418
assert payload["latency"] == "normal"
419419
assert payload["prosody"]["speed"] == 1.3
420420

421+
def test_convert_with_new_advanced_parameters(
422+
self, tts_client, mock_client_wrapper
423+
):
424+
"""Test TTS with new advanced generation parameters."""
425+
mock_response = Mock()
426+
mock_response.iter_bytes.return_value = iter([b"audio"])
427+
mock_client_wrapper.request.return_value = mock_response
428+
429+
config = TTSConfig(
430+
max_new_tokens=2048,
431+
repetition_penalty=1.5,
432+
min_chunk_length=100,
433+
condition_on_previous_chunks=False,
434+
early_stop_threshold=0.8,
435+
)
436+
tts_client.convert(text="Hello", config=config)
437+
438+
# Verify new parameters in payload
439+
call_args = mock_client_wrapper.request.call_args
440+
payload = ormsgpack.unpackb(call_args[1]["content"])
441+
assert payload["max_new_tokens"] == 2048
442+
assert payload["repetition_penalty"] == 1.5
443+
assert payload["min_chunk_length"] == 100
444+
assert payload["condition_on_previous_chunks"] is False
445+
assert payload["early_stop_threshold"] == 0.8
446+
447+
def test_convert_new_parameters_have_defaults(
448+
self, tts_client, mock_client_wrapper
449+
):
450+
"""Test TTS default values for new advanced parameters."""
451+
mock_response = Mock()
452+
mock_response.iter_bytes.return_value = iter([b"audio"])
453+
mock_client_wrapper.request.return_value = mock_response
454+
455+
tts_client.convert(text="Hello")
456+
457+
# Verify default values for new parameters in payload
458+
call_args = mock_client_wrapper.request.call_args
459+
payload = ormsgpack.unpackb(call_args[1]["content"])
460+
assert payload["max_new_tokens"] == 1024
461+
assert payload["repetition_penalty"] == 1.2
462+
assert payload["min_chunk_length"] == 50
463+
assert payload["condition_on_previous_chunks"] is True
464+
assert payload["early_stop_threshold"] == 1.0
465+
421466

422467
class TestAsyncTTSClient:
423468
"""Test asynchronous AsyncTTSClient."""

tests/unit/test_types.py

Lines changed: 83 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,8 @@
1111
Package,
1212
ReferenceAudio,
1313
Prosody,
14+
TTSConfig,
15+
TTSRequest,
1416
)
1517

1618

@@ -96,3 +98,84 @@ def test_prosody_custom(self):
9698
prosody = Prosody(speed=1.5, volume=0.5)
9799
assert prosody.speed == 1.5
98100
assert prosody.volume == 0.5
101+
102+
def test_tts_config_defaults(self):
103+
"""Test TTSConfig default values including new parameters."""
104+
config = TTSConfig()
105+
# Existing defaults
106+
assert config.format == "mp3"
107+
assert config.mp3_bitrate == 128
108+
assert config.opus_bitrate == 32
109+
assert config.normalize is True
110+
assert config.chunk_length == 200
111+
assert config.latency == "balanced"
112+
assert config.top_p == 0.7
113+
assert config.temperature == 0.7
114+
# New parameter defaults
115+
assert config.max_new_tokens == 1024
116+
assert config.repetition_penalty == 1.2
117+
assert config.min_chunk_length == 50
118+
assert config.condition_on_previous_chunks is True
119+
assert config.early_stop_threshold == 1.0
120+
121+
def test_tts_config_custom_new_parameters(self):
122+
"""Test TTSConfig with custom values for new parameters."""
123+
config = TTSConfig(
124+
max_new_tokens=2048,
125+
repetition_penalty=1.5,
126+
min_chunk_length=100,
127+
condition_on_previous_chunks=False,
128+
early_stop_threshold=0.8,
129+
)
130+
assert config.max_new_tokens == 2048
131+
assert config.repetition_penalty == 1.5
132+
assert config.min_chunk_length == 100
133+
assert config.condition_on_previous_chunks is False
134+
assert config.early_stop_threshold == 0.8
135+
136+
def test_tts_request_defaults(self):
137+
"""Test TTSRequest default values including new parameters."""
138+
request = TTSRequest(text="Hello world")
139+
# Existing defaults
140+
assert request.text == "Hello world"
141+
assert request.format == "mp3"
142+
assert request.chunk_length == 200
143+
assert request.latency == "balanced"
144+
# New parameter defaults
145+
assert request.max_new_tokens == 1024
146+
assert request.repetition_penalty == 1.2
147+
assert request.min_chunk_length == 50
148+
assert request.condition_on_previous_chunks is True
149+
assert request.early_stop_threshold == 1.0
150+
151+
def test_tts_request_custom_new_parameters(self):
152+
"""Test TTSRequest with custom values for new parameters."""
153+
request = TTSRequest(
154+
text="Hello world",
155+
max_new_tokens=512,
156+
repetition_penalty=1.0,
157+
min_chunk_length=25,
158+
condition_on_previous_chunks=False,
159+
early_stop_threshold=0.5,
160+
)
161+
assert request.max_new_tokens == 512
162+
assert request.repetition_penalty == 1.0
163+
assert request.min_chunk_length == 25
164+
assert request.condition_on_previous_chunks is False
165+
assert request.early_stop_threshold == 0.5
166+
167+
168+
class TestVoiceStates:
169+
"""Test Voice model with different states and visibility."""
170+
171+
def test_voice_with_ready_state(self, sample_voice_response):
172+
"""Test Voice model with 'ready' state."""
173+
sample_voice_response["state"] = "ready"
174+
voice = Voice.model_validate(sample_voice_response)
175+
assert voice.state == "ready"
176+
177+
def test_voice_with_unlisted_visibility(self, sample_voice_response):
178+
"""Test Voice model with 'unlisted' visibility."""
179+
sample_voice_response["visibility"] = "unlisted"
180+
voice = Voice.model_validate(sample_voice_response)
181+
assert voice.visibility == "unlisted"

0 commit comments

Comments
 (0)