Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/fishaudio/resources/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ def _config_to_tts_request(config: TTSConfig, text: str) -> TTSRequest:
prosody=config.prosody,
top_p=config.top_p,
temperature=config.temperature,
max_new_tokens=config.max_new_tokens,
repetition_penalty=config.repetition_penalty,
min_chunk_length=config.min_chunk_length,
condition_on_previous_chunks=config.condition_on_previous_chunks,
early_stop_threshold=config.early_stop_threshold,
)


Expand Down
2 changes: 1 addition & 1 deletion src/fishaudio/types/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class PaginatedResponse(BaseModel, Generic[T]):
AudioFormat = Literal["wav", "pcm", "mp3", "opus"]

# Visibility types
Visibility = Literal["public", "unlist", "private"]
Visibility = Literal["public", "unlisted", "private"]

# Training mode types
TrainMode = Literal["fast"]
Expand Down
22 changes: 22 additions & 0 deletions src/fishaudio/types/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ class TTSConfig(BaseModel):
top_p: Nucleus sampling parameter for token selection. Range: 0.0-1.0. Default: 0.7
temperature: Randomness in generation. Range: 0.0-1.0. Default: 0.7.
Higher = more varied, lower = more consistent
max_new_tokens: Maximum number of tokens to generate. Default: 1024
repetition_penalty: Penalty for repeated tokens. Default: 1.2
min_chunk_length: Minimum chunk length for generation. Default: 50
condition_on_previous_chunks: Whether to condition generation on previous chunks. Default: True
early_stop_threshold: Threshold for early stopping. Default: 1.0
"""

# Audio output settings
Expand All @@ -97,6 +102,13 @@ class TTSConfig(BaseModel):
top_p: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7
temperature: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7

# Advanced generation parameters
max_new_tokens: int = 1024
repetition_penalty: float = 1.2
min_chunk_length: int = 50
condition_on_previous_chunks: bool = True
early_stop_threshold: float = 1.0


class TTSRequest(BaseModel):
"""
Expand All @@ -119,6 +131,11 @@ class TTSRequest(BaseModel):
prosody: Speech speed and volume settings. Default: None
top_p: Nucleus sampling for token selection. Range: 0.0-1.0. Default: 0.7
temperature: Randomness in generation. Range: 0.0-1.0. Default: 0.7
max_new_tokens: Maximum number of tokens to generate. Default: 1024
repetition_penalty: Penalty for repeated tokens. Default: 1.2
min_chunk_length: Minimum chunk length for generation. Default: 50
condition_on_previous_chunks: Whether to condition generation on previous chunks. Default: True
early_stop_threshold: Threshold for early stopping. Default: 1.0
"""

text: str
Expand All @@ -134,6 +151,11 @@ class TTSRequest(BaseModel):
prosody: Optional[Prosody] = None
top_p: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7
temperature: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7
max_new_tokens: int = 1024
repetition_penalty: float = 1.2
min_chunk_length: int = 50
condition_on_previous_chunks: bool = True
early_stop_threshold: float = 1.0


# WebSocket event types for streaming TTS
Expand Down
45 changes: 45 additions & 0 deletions tests/unit/test_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,51 @@ def test_convert_combined_convenience_parameters(
assert payload["latency"] == "normal"
assert payload["prosody"]["speed"] == 1.3

def test_convert_with_new_advanced_parameters(
self, tts_client, mock_client_wrapper
):
"""Test TTS with new advanced generation parameters."""
mock_response = Mock()
mock_response.iter_bytes.return_value = iter([b"audio"])
mock_client_wrapper.request.return_value = mock_response

config = TTSConfig(
max_new_tokens=2048,
repetition_penalty=1.5,
min_chunk_length=100,
condition_on_previous_chunks=False,
early_stop_threshold=0.8,
)
tts_client.convert(text="Hello", config=config)

# Verify new parameters in payload
call_args = mock_client_wrapper.request.call_args
payload = ormsgpack.unpackb(call_args[1]["content"])
assert payload["max_new_tokens"] == 2048
assert payload["repetition_penalty"] == 1.5
assert payload["min_chunk_length"] == 100
assert payload["condition_on_previous_chunks"] is False
assert payload["early_stop_threshold"] == 0.8

def test_convert_new_parameters_have_defaults(
self, tts_client, mock_client_wrapper
):
"""Test TTS default values for new advanced parameters."""
mock_response = Mock()
mock_response.iter_bytes.return_value = iter([b"audio"])
mock_client_wrapper.request.return_value = mock_response

tts_client.convert(text="Hello")

# Verify default values for new parameters in payload
call_args = mock_client_wrapper.request.call_args
payload = ormsgpack.unpackb(call_args[1]["content"])
assert payload["max_new_tokens"] == 1024
assert payload["repetition_penalty"] == 1.2
assert payload["min_chunk_length"] == 50
assert payload["condition_on_previous_chunks"] is True
assert payload["early_stop_threshold"] == 1.0


class TestAsyncTTSClient:
"""Test asynchronous AsyncTTSClient."""
Expand Down
77 changes: 77 additions & 0 deletions tests/unit/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
Package,
ReferenceAudio,
Prosody,
TTSConfig,
TTSRequest,
)


Expand Down Expand Up @@ -96,3 +98,78 @@ def test_prosody_custom(self):
prosody = Prosody(speed=1.5, volume=0.5)
assert prosody.speed == 1.5
assert prosody.volume == 0.5

def test_tts_config_defaults(self):
"""Test TTSConfig default values including new parameters."""
config = TTSConfig()
# Existing defaults
assert config.format == "mp3"
assert config.mp3_bitrate == 128
assert config.opus_bitrate == 32
assert config.normalize is True
assert config.chunk_length == 200
assert config.latency == "balanced"
assert config.top_p == 0.7
assert config.temperature == 0.7
# New parameter defaults
assert config.max_new_tokens == 1024
assert config.repetition_penalty == 1.2
assert config.min_chunk_length == 50
assert config.condition_on_previous_chunks is True
assert config.early_stop_threshold == 1.0

def test_tts_config_custom_new_parameters(self):
"""Test TTSConfig with custom values for new parameters."""
config = TTSConfig(
max_new_tokens=2048,
repetition_penalty=1.5,
min_chunk_length=100,
condition_on_previous_chunks=False,
early_stop_threshold=0.8,
)
assert config.max_new_tokens == 2048
assert config.repetition_penalty == 1.5
assert config.min_chunk_length == 100
assert config.condition_on_previous_chunks is False
assert config.early_stop_threshold == 0.8

def test_tts_request_defaults(self):
"""Test TTSRequest default values including new parameters."""
request = TTSRequest(text="Hello world")
# Existing defaults
assert request.text == "Hello world"
assert request.format == "mp3"
assert request.chunk_length == 200
assert request.latency == "balanced"
# New parameter defaults
assert request.max_new_tokens == 1024
assert request.repetition_penalty == 1.2
assert request.min_chunk_length == 50
assert request.condition_on_previous_chunks is True
assert request.early_stop_threshold == 1.0

def test_tts_request_custom_new_parameters(self):
"""Test TTSRequest with custom values for new parameters."""
request = TTSRequest(
text="Hello world",
max_new_tokens=512,
repetition_penalty=1.0,
min_chunk_length=25,
condition_on_previous_chunks=False,
early_stop_threshold=0.5,
)
assert request.max_new_tokens == 512
assert request.repetition_penalty == 1.0
assert request.min_chunk_length == 25
assert request.condition_on_previous_chunks is False
assert request.early_stop_threshold == 0.5


class TestVoiceVisibility:
"""Test Voice model with updated visibility."""

def test_voice_with_unlisted_visibility(self, sample_voice_response):
"""Test Voice model with 'unlisted' visibility."""
sample_voice_response["visibility"] = "unlisted"
voice = Voice.model_validate(sample_voice_response)
assert voice.visibility == "unlisted"