Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/fishaudio/resources/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ def _config_to_tts_request(config: TTSConfig, text: str) -> TTSRequest:
prosody=config.prosody,
top_p=config.top_p,
temperature=config.temperature,
max_new_tokens=config.max_new_tokens,
repetition_penalty=config.repetition_penalty,
min_chunk_length=config.min_chunk_length,
condition_on_previous_chunks=config.condition_on_previous_chunks,
early_stop_threshold=config.early_stop_threshold,
)


Expand Down
4 changes: 2 additions & 2 deletions src/fishaudio/types/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ class PaginatedResponse(BaseModel, Generic[T]):
AudioFormat = Literal["wav", "pcm", "mp3", "opus"]

# Visibility types
Visibility = Literal["public", "unlist", "private"]
Visibility = Literal["public", "unlisted", "private"]

# Training mode types
TrainMode = Literal["fast"]

# Model state types
ModelState = Literal["created", "training", "trained", "failed"]
ModelState = Literal["created", "training", "trained", "failed", "ready"]

# Latency modes
LatencyMode = Literal["normal", "balanced"]
22 changes: 22 additions & 0 deletions src/fishaudio/types/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ class TTSConfig(BaseModel):
top_p: Nucleus sampling parameter for token selection. Range: 0.0-1.0. Default: 0.7
temperature: Randomness in generation. Range: 0.0-1.0. Default: 0.7.
Higher = more varied, lower = more consistent
max_new_tokens: Maximum number of tokens to generate. Default: 1024
repetition_penalty: Penalty for repeated tokens. Default: 1.2
min_chunk_length: Minimum chunk length for generation. Default: 50
condition_on_previous_chunks: Whether to condition generation on previous chunks. Default: True
early_stop_threshold: Threshold for early stopping. Default: 1.0
"""

# Audio output settings
Expand All @@ -97,6 +102,13 @@ class TTSConfig(BaseModel):
top_p: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7
temperature: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7

# Advanced generation parameters
max_new_tokens: int = 1024
repetition_penalty: float = 1.2
min_chunk_length: int = 50
condition_on_previous_chunks: bool = True
early_stop_threshold: float = 1.0


class TTSRequest(BaseModel):
"""
Expand All @@ -119,6 +131,11 @@ class TTSRequest(BaseModel):
prosody: Speech speed and volume settings. Default: None
top_p: Nucleus sampling for token selection. Range: 0.0-1.0. Default: 0.7
temperature: Randomness in generation. Range: 0.0-1.0. Default: 0.7
max_new_tokens: Maximum number of tokens to generate. Default: 1024
repetition_penalty: Penalty for repeated tokens. Default: 1.2
min_chunk_length: Minimum chunk length for generation. Default: 50
condition_on_previous_chunks: Whether to condition generation on previous chunks. Default: True
early_stop_threshold: Threshold for early stopping. Default: 1.0
"""

text: str
Expand All @@ -134,6 +151,11 @@ class TTSRequest(BaseModel):
prosody: Optional[Prosody] = None
top_p: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7
temperature: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7
max_new_tokens: int = 1024
repetition_penalty: float = 1.2
min_chunk_length: int = 50
condition_on_previous_chunks: bool = True
early_stop_threshold: float = 1.0


# WebSocket event types for streaming TTS
Expand Down
45 changes: 45 additions & 0 deletions tests/unit/test_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,51 @@ def test_convert_combined_convenience_parameters(
assert payload["latency"] == "normal"
assert payload["prosody"]["speed"] == 1.3

def test_convert_with_new_advanced_parameters(
self, tts_client, mock_client_wrapper
):
"""Test TTS with new advanced generation parameters."""
mock_response = Mock()
mock_response.iter_bytes.return_value = iter([b"audio"])
mock_client_wrapper.request.return_value = mock_response

config = TTSConfig(
max_new_tokens=2048,
repetition_penalty=1.5,
min_chunk_length=100,
condition_on_previous_chunks=False,
early_stop_threshold=0.8,
)
tts_client.convert(text="Hello", config=config)

# Verify new parameters in payload
call_args = mock_client_wrapper.request.call_args
payload = ormsgpack.unpackb(call_args[1]["content"])
assert payload["max_new_tokens"] == 2048
assert payload["repetition_penalty"] == 1.5
assert payload["min_chunk_length"] == 100
assert payload["condition_on_previous_chunks"] is False
assert payload["early_stop_threshold"] == 0.8

def test_convert_new_parameters_have_defaults(
self, tts_client, mock_client_wrapper
):
"""Test TTS default values for new advanced parameters."""
mock_response = Mock()
mock_response.iter_bytes.return_value = iter([b"audio"])
mock_client_wrapper.request.return_value = mock_response

tts_client.convert(text="Hello")

# Verify default values for new parameters in payload
call_args = mock_client_wrapper.request.call_args
payload = ormsgpack.unpackb(call_args[1]["content"])
assert payload["max_new_tokens"] == 1024
assert payload["repetition_penalty"] == 1.2
assert payload["min_chunk_length"] == 50
assert payload["condition_on_previous_chunks"] is True
assert payload["early_stop_threshold"] == 1.0


class TestAsyncTTSClient:
"""Test asynchronous AsyncTTSClient."""
Expand Down
83 changes: 83 additions & 0 deletions tests/unit/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
Package,
ReferenceAudio,
Prosody,
TTSConfig,
TTSRequest,
)


Expand Down Expand Up @@ -96,3 +98,84 @@ def test_prosody_custom(self):
prosody = Prosody(speed=1.5, volume=0.5)
assert prosody.speed == 1.5
assert prosody.volume == 0.5

def test_tts_config_defaults(self):
"""Test TTSConfig default values including new parameters."""
config = TTSConfig()
# Existing defaults
assert config.format == "mp3"
assert config.mp3_bitrate == 128
assert config.opus_bitrate == 32
assert config.normalize is True
assert config.chunk_length == 200
assert config.latency == "balanced"
assert config.top_p == 0.7
assert config.temperature == 0.7
# New parameter defaults
assert config.max_new_tokens == 1024
assert config.repetition_penalty == 1.2
assert config.min_chunk_length == 50
assert config.condition_on_previous_chunks is True
assert config.early_stop_threshold == 1.0

def test_tts_config_custom_new_parameters(self):
"""Test TTSConfig with custom values for new parameters."""
config = TTSConfig(
max_new_tokens=2048,
repetition_penalty=1.5,
min_chunk_length=100,
condition_on_previous_chunks=False,
early_stop_threshold=0.8,
)
assert config.max_new_tokens == 2048
assert config.repetition_penalty == 1.5
assert config.min_chunk_length == 100
assert config.condition_on_previous_chunks is False
assert config.early_stop_threshold == 0.8

def test_tts_request_defaults(self):
"""Test TTSRequest default values including new parameters."""
request = TTSRequest(text="Hello world")
# Existing defaults
assert request.text == "Hello world"
assert request.format == "mp3"
assert request.chunk_length == 200
assert request.latency == "balanced"
# New parameter defaults
assert request.max_new_tokens == 1024
assert request.repetition_penalty == 1.2
assert request.min_chunk_length == 50
assert request.condition_on_previous_chunks is True
assert request.early_stop_threshold == 1.0

def test_tts_request_custom_new_parameters(self):
"""Test TTSRequest with custom values for new parameters."""
request = TTSRequest(
text="Hello world",
max_new_tokens=512,
repetition_penalty=1.0,
min_chunk_length=25,
condition_on_previous_chunks=False,
early_stop_threshold=0.5,
)
assert request.max_new_tokens == 512
assert request.repetition_penalty == 1.0
assert request.min_chunk_length == 25
assert request.condition_on_previous_chunks is False
assert request.early_stop_threshold == 0.5


class TestVoiceStates:
"""Test Voice model with different states and visibility."""

def test_voice_with_ready_state(self, sample_voice_response):
"""Test Voice model with 'ready' state."""
sample_voice_response["state"] = "ready"
voice = Voice.model_validate(sample_voice_response)
assert voice.state == "ready"

def test_voice_with_unlisted_visibility(self, sample_voice_response):
"""Test Voice model with 'unlisted' visibility."""
sample_voice_response["visibility"] = "unlisted"
voice = Voice.model_validate(sample_voice_response)
assert voice.visibility == "unlisted"
Loading