Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 105 additions & 72 deletions rag/llm/tts_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,68 @@ def normalize_text(self, text):
return re.sub(r"(\*\*|##\d+\$\$|#)", "", text)


class HTTPBasedTTS(Base):
    """
    Base class for HTTP-based TTS services.

    Provides common HTTP request handling and response processing: it builds
    a JSON payload, POSTs it to ``base_url + endpoint`` and yields the audio
    bytes of the response. Subclasses customise behaviour by overriding
    ``_build_payload``, ``_process_response`` or ``tts``.
    """

    def __init__(self, key, model_name, base_url, **kwargs):
        """
        Store connection settings and prepare the request headers.

        Args:
            key: API key. When it is empty or the literal ``"x"``, no
                ``Authorization`` header is sent.
            model_name: Model identifier placed in the request payload.
            base_url: URL prefix that endpoints are appended to verbatim.
            **kwargs: Optional settings. ``timeout`` (seconds, or a
                ``(connect, read)`` tuple) is forwarded to ``requests``;
                it defaults to ``None``, i.e. wait indefinitely.
        """
        self.model_name = model_name
        self.base_url = base_url
        self.api_key = key
        # None preserves the previous behaviour (no timeout at all).
        self.timeout = kwargs.get("timeout")
        self.headers = {
            "Content-Type": "application/json"
        }
        if key and key != "x":
            self.headers["Authorization"] = f"Bearer {self.api_key}"

    def _build_payload(self, text, voice, **kwargs):
        """
        Build payload for TTS request.

        Subclasses should override this method if they need custom payload
        structure. Extra keyword arguments are accepted but unused here.
        """
        return {
            "model": self.model_name,
            "voice": voice,
            "input": text
        }

    def _send_request(self, endpoint, payload, stream=True):
        """
        Send HTTP request to TTS service.

        Args:
            endpoint: Path appended to ``self.base_url`` without modification.
            payload: JSON-serialisable request body.
            stream: Passed to ``requests.post``; when true the body is not
                downloaded until it is iterated.

        Returns:
            The ``requests`` response object for a 200 reply.

        Raises:
            Exception: If the service answers with any status other than 200.
        """
        url = f"{self.base_url}{endpoint}"
        response = requests.post(
            url,
            headers=self.headers,
            json=payload,
            stream=stream,
            # getattr keeps subclasses working even if they build their
            # state without going through this class's __init__.
            timeout=getattr(self, "timeout", None),
        )

        if response.status_code != 200:
            # Read the error body first, then release the connection so a
            # failed streamed request does not leak it.
            try:
                message = f"**Error**: {response.status_code}, {response.text}"
            finally:
                response.close()
            raise Exception(message)

        return response

    def _process_response(self, response):
        """
        Process streaming response from TTS service.

        Yields every non-empty chunk of the response body and closes the
        response once iteration finishes, fails, or the generator is closed.
        """
        # NOTE(review): iter_content() without arguments uses requests'
        # default chunk size; subclasses may override with a larger one.
        try:
            for chunk in response.iter_content():
                if chunk:
                    yield chunk
        finally:
            response.close()

    def tts(self, text, voice="alloy"):
        """
        Generate speech from text.

        The text is passed through the inherited ``normalize_text`` and sent
        to the ``/audio/speech`` endpoint. The request is issued immediately;
        the returned generator yields the audio bytes.
        """
        text = self.normalize_text(text)
        payload = self._build_payload(text, voice)
        response = self._send_request("/audio/speech", payload)
        return self._process_response(response)


class FishAudioTTS(Base):
_FACTORY_NAME = "Fish Audio"

Expand Down Expand Up @@ -178,28 +240,13 @@ def on_event(self, result: SpeechSynthesisResult):
raise RuntimeError(f"**ERROR**: {e}")


class OpenAITTS(Base):
class OpenAITTS(HTTPBasedTTS):
_FACTORY_NAME = "OpenAI"

def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
if not base_url:
base_url = "https://api.openai.com/v1"
self.api_key = key
self.model_name = model_name
self.base_url = base_url
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

def tts(self, text, voice="alloy"):
text = self.normalize_text(text)
payload = {"model": self.model_name, "voice": voice, "input": text}

response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)

if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
for chunk in response.iter_content():
if chunk:
yield chunk
super().__init__(key, model_name, base_url)


class SparkTTS(Base):
Expand Down Expand Up @@ -291,86 +338,74 @@ def run(*args):
yield audio_chunk


class XinferenceTTS(Base):
class XinferenceTTS(HTTPBasedTTS):
_FACTORY_NAME = "Xinference"

def __init__(self, key, model_name, **kwargs):
self.base_url = kwargs.get("base_url", None)
self.model_name = model_name
base_url = kwargs.get("base_url", None)
super().__init__(key, model_name, base_url)
# Override headers to remove Authorization
self.headers = {"accept": "application/json", "Content-Type": "application/json"}

def tts(self, text, voice="中文女", stream=True):
payload = {"model": self.model_name, "input": text, "voice": voice}

response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)

if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")

def _process_response(self, response):
# Use chunk_size=1024 for processing response
for chunk in response.iter_content(chunk_size=1024):
if chunk:
yield chunk

def tts(self, text, voice="中文女", stream=True):
text = self.normalize_text(text)
payload = self._build_payload(text, voice)
response = self._send_request("/v1/audio/speech", payload, stream=stream)
return self._process_response(response)


class OllamaTTS(Base):
class OllamaTTS(HTTPBasedTTS):
def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
if not base_url:
base_url = "https://api.ollama.ai/v1"
self.model_name = model_name
self.base_url = base_url
self.headers = {"Content-Type": "application/json"}
if key and key != "x":
self.headers["Authorization"] = f"Bearer {key}"
super().__init__(key, model_name, base_url)

def tts(self, text, voice="standard-voice"):
payload = {"model": self.model_name, "voice": voice, "input": text}

response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)

if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")

for chunk in response.iter_content():
if chunk:
yield chunk
text = self.normalize_text(text)
payload = self._build_payload(text, voice)
response = self._send_request("/audio/tts", payload)
return self._process_response(response)


class GPUStackTTS(Base):
class GPUStackTTS(HTTPBasedTTS):
_FACTORY_NAME = "GPUStack"

def __init__(self, key, model_name, **kwargs):
self.base_url = kwargs.get("base_url", None)
self.api_key = key
self.model_name = model_name
self.headers = {"accept": "application/json", "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

def tts(self, text, voice="Chinese Female", stream=True):
payload = {"model": self.model_name, "input": text, "voice": voice}

response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)

if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
base_url = kwargs.get("base_url", None)
super().__init__(key, model_name, base_url)
# Add accept header
self.headers["accept"] = "application/json"

def _process_response(self, response):
# Use chunk_size=1024 for processing response
for chunk in response.iter_content(chunk_size=1024):
if chunk:
yield chunk

def tts(self, text, voice="Chinese Female", stream=True):
text = self.normalize_text(text)
payload = self._build_payload(text, voice)
response = self._send_request("/v1/audio/speech", payload, stream=stream)
return self._process_response(response)

class SILICONFLOWTTS(Base):

class SILICONFLOWTTS(HTTPBasedTTS):
_FACTORY_NAME = "SILICONFLOW"

def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"):
if not base_url:
base_url = "https://api.siliconflow.cn/v1"
self.api_key = key
self.model_name = model_name
self.base_url = base_url
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
super().__init__(key, model_name, base_url)

def tts(self, text, voice="anna"):
text = self.normalize_text(text)
payload = {
def _build_payload(self, text, voice, **kwargs):
# Custom payload structure for SILICONFLOW
return {
"model": self.model_name,
"input": text,
"voice": f"{self.model_name}:{voice}",
Expand All @@ -381,13 +416,11 @@ def tts(self, text, voice="anna"):
"gain": 0,
}

response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload)

if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
for chunk in response.iter_content():
if chunk:
yield chunk
def tts(self, text, voice="anna"):
text = self.normalize_text(text)
payload = self._build_payload(text, voice)
response = self._send_request("/audio/speech", payload)
return self._process_response(response)


class DeepInfraTTS(OpenAITTS):
Expand Down