# Patch summary (free text placed ahead of the first "diff --git" header).
#
# What the hunks below change:
#   * src/fishaudio/resources/tts.py: _config_to_tts_request additionally passes
#     max_new_tokens, repetition_penalty, min_chunk_length,
#     condition_on_previous_chunks and early_stop_threshold from the TTSConfig
#     to the TTSRequest it constructs.
#   * src/fishaudio/types/shared.py: the Visibility Literal member "unlist" is
#     replaced by "unlisted".
#   * src/fishaudio/types/tts.py: TTSConfig and TTSRequest each gain the same
#     five fields (defaults 1024, 1.2, 50, True, 1.0) plus docstring entries.
#   * tests/unit/test_tts.py: two tests decode the request content with
#     ormsgpack and check custom and default values of the five fields.
#   * tests/unit/test_types.py: tests for TTSConfig/TTSRequest defaults and
#     custom values, and a Voice.model_validate test using "unlisted".
#
# NOTE(review): the five new fields are annotated as plain int/float/bool,
#   while the neighbouring top_p and temperature use
#   Annotated[float, Field(ge=0.0, le=1.0)] - confirm whether the API defines
#   ranges that should be enforced for the new fields as well.
# NOTE(review): the defaults are spelled out twice (TTSConfig and TTSRequest);
#   they have to be kept in sync by hand.
# NOTE(review): "unlist" is no longer a member of Visibility after this patch;
#   presumably a payload still carrying it would fail validation - confirm the
#   server emits "unlisted" before merging.
# NOTE(review): this copy of the patch appears to have its line breaks and
#   indentation collapsed - restore them from the original before applying.
diff --git a/src/fishaudio/resources/tts.py b/src/fishaudio/resources/tts.py index ba561ab..89be523 100644 --- a/src/fishaudio/resources/tts.py +++ b/src/fishaudio/resources/tts.py @@ -41,6 +41,11 @@ def _config_to_tts_request(config: TTSConfig, text: str) -> TTSRequest: prosody=config.prosody, top_p=config.top_p, temperature=config.temperature, + max_new_tokens=config.max_new_tokens, + repetition_penalty=config.repetition_penalty, + min_chunk_length=config.min_chunk_length, + condition_on_previous_chunks=config.condition_on_previous_chunks, + early_stop_threshold=config.early_stop_threshold, ) diff --git a/src/fishaudio/types/shared.py b/src/fishaudio/types/shared.py index 1e756d9..c9da289 100644 --- a/src/fishaudio/types/shared.py +++ b/src/fishaudio/types/shared.py @@ -27,7 +27,7 @@ class PaginatedResponse(BaseModel, Generic[T]): AudioFormat = Literal["wav", "pcm", "mp3", "opus"] # Visibility types -Visibility = Literal["public", "unlist", "private"] +Visibility = Literal["public", "unlisted", "private"] # Training mode types TrainMode = Literal["fast"] diff --git a/src/fishaudio/types/tts.py b/src/fishaudio/types/tts.py index 8b0923a..00f72da 100644 --- a/src/fishaudio/types/tts.py +++ b/src/fishaudio/types/tts.py @@ -75,6 +75,11 @@ class TTSConfig(BaseModel): top_p: Nucleus sampling parameter for token selection. Range: 0.0-1.0. Default: 0.7 temperature: Randomness in generation. Range: 0.0-1.0. Default: 0.7. Higher = more varied, lower = more consistent + max_new_tokens: Maximum number of tokens to generate. Default: 1024 + repetition_penalty: Penalty for repeated tokens. Default: 1.2 + min_chunk_length: Minimum chunk length for generation. Default: 50 + condition_on_previous_chunks: Whether to condition generation on previous chunks. Default: True + early_stop_threshold: Threshold for early stopping. 
Default: 1.0 """ # Audio output settings @@ -97,6 +102,13 @@ class TTSConfig(BaseModel): top_p: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7 temperature: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7 + # Advanced generation parameters + max_new_tokens: int = 1024 + repetition_penalty: float = 1.2 + min_chunk_length: int = 50 + condition_on_previous_chunks: bool = True + early_stop_threshold: float = 1.0 + class TTSRequest(BaseModel): """ @@ -119,6 +131,11 @@ class TTSRequest(BaseModel): prosody: Speech speed and volume settings. Default: None top_p: Nucleus sampling for token selection. Range: 0.0-1.0. Default: 0.7 temperature: Randomness in generation. Range: 0.0-1.0. Default: 0.7 + max_new_tokens: Maximum number of tokens to generate. Default: 1024 + repetition_penalty: Penalty for repeated tokens. Default: 1.2 + min_chunk_length: Minimum chunk length for generation. Default: 50 + condition_on_previous_chunks: Whether to condition generation on previous chunks. Default: True + early_stop_threshold: Threshold for early stopping. 
Default: 1.0 """ text: str @@ -134,6 +151,11 @@ class TTSRequest(BaseModel): prosody: Optional[Prosody] = None top_p: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7 temperature: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7 + max_new_tokens: int = 1024 + repetition_penalty: float = 1.2 + min_chunk_length: int = 50 + condition_on_previous_chunks: bool = True + early_stop_threshold: float = 1.0 # WebSocket event types for streaming TTS diff --git a/tests/unit/test_tts.py b/tests/unit/test_tts.py index 47bfb06..d556c2c 100644 --- a/tests/unit/test_tts.py +++ b/tests/unit/test_tts.py @@ -418,6 +418,51 @@ def test_convert_combined_convenience_parameters( assert payload["latency"] == "normal" assert payload["prosody"]["speed"] == 1.3 + def test_convert_with_new_advanced_parameters( + self, tts_client, mock_client_wrapper + ): + """Test TTS with new advanced generation parameters.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + config = TTSConfig( + max_new_tokens=2048, + repetition_penalty=1.5, + min_chunk_length=100, + condition_on_previous_chunks=False, + early_stop_threshold=0.8, + ) + tts_client.convert(text="Hello", config=config) + + # Verify new parameters in payload + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["max_new_tokens"] == 2048 + assert payload["repetition_penalty"] == 1.5 + assert payload["min_chunk_length"] == 100 + assert payload["condition_on_previous_chunks"] is False + assert payload["early_stop_threshold"] == 0.8 + + def test_convert_new_parameters_have_defaults( + self, tts_client, mock_client_wrapper + ): + """Test TTS default values for new advanced parameters.""" + mock_response = Mock() + mock_response.iter_bytes.return_value = iter([b"audio"]) + mock_client_wrapper.request.return_value = mock_response + + tts_client.convert(text="Hello") + + # Verify default values 
for new parameters in payload + call_args = mock_client_wrapper.request.call_args + payload = ormsgpack.unpackb(call_args[1]["content"]) + assert payload["max_new_tokens"] == 1024 + assert payload["repetition_penalty"] == 1.2 + assert payload["min_chunk_length"] == 50 + assert payload["condition_on_previous_chunks"] is True + assert payload["early_stop_threshold"] == 1.0 + class TestAsyncTTSClient: """Test asynchronous AsyncTTSClient.""" diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index e9dee3d..f5a9f02 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -11,6 +11,8 @@ Package, ReferenceAudio, Prosody, + TTSConfig, + TTSRequest, ) @@ -96,3 +98,78 @@ def test_prosody_custom(self): prosody = Prosody(speed=1.5, volume=0.5) assert prosody.speed == 1.5 assert prosody.volume == 0.5 + + def test_tts_config_defaults(self): + """Test TTSConfig default values including new parameters.""" + config = TTSConfig() + # Existing defaults + assert config.format == "mp3" + assert config.mp3_bitrate == 128 + assert config.opus_bitrate == 32 + assert config.normalize is True + assert config.chunk_length == 200 + assert config.latency == "balanced" + assert config.top_p == 0.7 + assert config.temperature == 0.7 + # New parameter defaults + assert config.max_new_tokens == 1024 + assert config.repetition_penalty == 1.2 + assert config.min_chunk_length == 50 + assert config.condition_on_previous_chunks is True + assert config.early_stop_threshold == 1.0 + + def test_tts_config_custom_new_parameters(self): + """Test TTSConfig with custom values for new parameters.""" + config = TTSConfig( + max_new_tokens=2048, + repetition_penalty=1.5, + min_chunk_length=100, + condition_on_previous_chunks=False, + early_stop_threshold=0.8, + ) + assert config.max_new_tokens == 2048 + assert config.repetition_penalty == 1.5 + assert config.min_chunk_length == 100 + assert config.condition_on_previous_chunks is False + assert config.early_stop_threshold == 0.8 + + 
def test_tts_request_defaults(self): + """Test TTSRequest default values including new parameters.""" + request = TTSRequest(text="Hello world") + # Existing defaults + assert request.text == "Hello world" + assert request.format == "mp3" + assert request.chunk_length == 200 + assert request.latency == "balanced" + # New parameter defaults + assert request.max_new_tokens == 1024 + assert request.repetition_penalty == 1.2 + assert request.min_chunk_length == 50 + assert request.condition_on_previous_chunks is True + assert request.early_stop_threshold == 1.0 + + def test_tts_request_custom_new_parameters(self): + """Test TTSRequest with custom values for new parameters.""" + request = TTSRequest( + text="Hello world", + max_new_tokens=512, + repetition_penalty=1.0, + min_chunk_length=25, + condition_on_previous_chunks=False, + early_stop_threshold=0.5, + ) + assert request.max_new_tokens == 512 + assert request.repetition_penalty == 1.0 + assert request.min_chunk_length == 25 + assert request.condition_on_previous_chunks is False + assert request.early_stop_threshold == 0.5 + + +class TestVoiceVisibility: + """Test Voice model with updated visibility.""" + + def test_voice_with_unlisted_visibility(self, sample_voice_response): + """Test Voice model with 'unlisted' visibility.""" + sample_voice_response["visibility"] = "unlisted" + voice = Voice.model_validate(sample_voice_response) + assert voice.visibility == "unlisted"