|
| 1 | +from collections.abc import Iterable |
| 2 | +from typing import Optional |
| 3 | +from urllib.parse import urljoin |
| 4 | + |
| 5 | +import requests |
| 6 | + |
| 7 | +from core.model_runtime.entities.common_entities import I18nObject |
| 8 | +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType |
| 9 | +from core.model_runtime.errors.invoke import InvokeBadRequestError |
| 10 | +from core.model_runtime.errors.validate import CredentialsValidateFailedError |
| 11 | +from core.model_runtime.model_providers.__base.tts_model import TTSModel |
| 12 | +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat |
| 13 | + |
| 14 | + |
class OAICompatText2SpeechModel(_CommonOaiApiCompat, TTSModel):
    """
    Model class for OpenAI-compatible text2speech model.

    Talks to any endpoint implementing the OpenAI `/audio/speech` API shape
    and streams raw audio bytes back to the caller.
    """

    def _invoke(
        self,
        model: str,
        tenant_id: str,
        credentials: dict,
        content_text: str,
        voice: str,
        user: Optional[str] = None,
    ) -> Iterable[bytes]:
        """
        Invoke TTS model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials (must contain `endpoint_url`;
            `api_key` is optional)
        :param content_text: text content to be translated
        :param voice: model voice/speaker
        :param user: unique user id
        :return: audio data as bytes iterator
        :raises CredentialsValidateFailedError: if `endpoint_url` is missing
        :raises InvokeBadRequestError: if the server returns a non-200 status
        """
        # Set up headers with authentication if provided
        headers = {}
        if api_key := credentials.get("api_key"):
            headers["Authorization"] = f"Bearer {api_key}"

        # Construct endpoint URL. Fail fast with a clear error instead of an
        # opaque AttributeError when the endpoint is missing from credentials.
        endpoint_url = credentials.get("endpoint_url")
        if not endpoint_url:
            raise CredentialsValidateFailedError("endpoint_url is required in credentials")
        # Trailing slash keeps urljoin from dropping the last path segment.
        if not endpoint_url.endswith("/"):
            endpoint_url += "/"
        endpoint_url = urljoin(endpoint_url, "audio/speech")

        # Get audio format from model properties
        audio_format = self._get_model_audio_type(model, credentials)

        # Split text into chunks if needed based on word limit
        word_limit = self._get_model_word_limit(model, credentials)
        sentences = self._split_text_into_sentences(content_text, word_limit)

        for sentence in sentences:
            # Prepare request payload
            payload = {"model": model, "input": sentence, "voice": voice, "response_format": audio_format}

            # Stream the audio. The timeout (connect, read) prevents hanging
            # forever on an unresponsive server; the context manager releases
            # the connection back to the pool once streaming finishes.
            with requests.post(
                endpoint_url, headers=headers, json=payload, stream=True, timeout=(10, 300)
            ) as response:
                if response.status_code != 200:
                    raise InvokeBadRequestError(response.text)

                for chunk in response.iter_content(chunk_size=4096):
                    if chunk:
                        yield chunk

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        :raises CredentialsValidateFailedError: if a test invocation fails
        """
        try:
            # Get default voice for validation
            voice = self._get_model_default_voice(model, credentials)

            # Test with a simple text; pulling one chunk is enough to prove
            # the endpoint accepts the request and returns audio data.
            next(
                self._invoke(
                    model=model, tenant_id="validate", credentials=credentials, content_text="Test.", voice=voice
                )
            )
        except Exception as ex:
            # Chain the cause so the original traceback is preserved.
            raise CredentialsValidateFailedError(str(ex)) from ex

    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
        """
        Get customizable model schema

        Builds the TTS model entity from credentials: voices come from the
        comma-separated `voices` field, audio type and word limit from their
        respective credential keys (with OpenAI-compatible defaults).
        """
        # Parse voices from comma-separated string
        voice_names = credentials.get("voices", "alloy").strip().split(",")
        voices = []

        for voice in voice_names:
            voice = voice.strip()
            if not voice:
                continue

            # Use en-US for all voices
            voices.append(
                {
                    "name": voice,
                    "mode": voice,
                    "language": "en-US",
                }
            )

        # If no voices provided or all voices were empty strings, use 'alloy' as default
        if not voices:
            voices = [{"name": "Alloy", "mode": "alloy", "language": "en-US"}]

        return AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TTS,
            model_properties={
                ModelPropertyKey.AUDIO_TYPE: credentials.get("audio_type", "mp3"),
                ModelPropertyKey.WORD_LIMIT: int(credentials.get("word_limit", 4096)),
                # The first configured voice doubles as the default.
                ModelPropertyKey.DEFAULT_VOICE: voices[0]["mode"],
                ModelPropertyKey.VOICES: voices,
            },
        )

    def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
        """
        Override base get_tts_model_voices to handle customizable voices

        :param model: model name
        :param credentials: model credentials
        :param language: ignored — all configured voices are returned
        :return: list of {"name", "value"} dicts for UI selection
        :raises ValueError: if the schema exposes no voices
        """
        model_schema = self.get_customizable_model_schema(model, credentials)

        if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties:
            raise ValueError("this model does not support voice")

        voices = model_schema.model_properties[ModelPropertyKey.VOICES]

        # Always return all voices regardless of language
        return [{"name": d["name"], "value": d["mode"]} for d in voices]
0 commit comments