Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 65 additions & 9 deletions apps/models_provider/impl/xf_model_provider/credential/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,21 @@
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger


class XunFeiTTSModelGeneralParams(BaseForm):
vcn = forms.SingleSelect(
TooltipLabel(_('Speaker'),
_('Speaker, optional value: Please go to the console to add a trial or purchase speaker. After adding, the speaker parameter value will be displayed.')),
api_version = forms.Radio('API Version', required=False, text_field='label', value_field='value',
option_list=[
{'label': '在线语音', 'value': 'online'},
{'label': '超拟人语音', 'value': 'super_humanoid'}
],
default_value='online',
provider='',
method='',
props_info={'item_style': {'display': 'none'}})

vcn_online = forms.SingleSelect(
TooltipLabel(_('Speaker'), _('Speaker selection for standard TTS service')),
required=True, default_value='xiaoyan',
text_field='value',
value_field='value',
Expand All @@ -22,7 +31,35 @@ class XunFeiTTSModelGeneralParams(BaseForm):
{'text': _('iFlytek Xiaoping'), 'value': 'aisxping'},
{'text': _('iFlytek Xiaojing'), 'value': 'aisjinger'},
{'text': _('iFlytek Xuxiaobao'), 'value': 'aisbabyxu'},
])
],
relation_show_field_dict={"api_version": ["online"]})

vcn_super = forms.SingleSelect(
TooltipLabel(_('Speaker'), _('Speaker selection for super-humanoid TTS service')),
required=True, default_value='x5_lingxiaoxuan_flow',
text_field='value',
value_field='value',
option_list=[
{'text': _('Super-humanoid: Lingxiaoxuan Flow'), 'value': 'x5_lingxiaoxuan_flow'},
{'text': _('Super-humanoid: Lingyuyan Flow'), 'value': 'x5_lingyuyan_flow'},
{'text': _('Super-humanoid: Lingfeiyi Flow'), 'value': 'x5_lingfeiyi_flow'},
{'text': _('Super-humanoid: Lingxiaoyue Flow'), 'value': 'x5_lingxiaoyue_flow'},
{'text': _('Super-humanoid: Sun Dasheng Flow'), 'value': 'x5_sundasheng_flow'},
{'text': _('Super-humanoid: Lingyuzhao Flow'), 'value': 'x5_lingyuzhao_flow'},
{'text': _('Super-humanoid: Lingxiaotang Flow'), 'value': 'x5_lingxiaotang_flow'},
{'text': _('Super-humanoid: Lingxiaorong Flow'), 'value': 'x5_lingxiaorong_flow'},
{'text': _('Super-humanoid: Xinyun Flow'), 'value': 'x5_xinyun_flow'},
{'text': _('Super-humanoid: Grant (EN)'), 'value': 'x5_EnUs_Grant_flow'},
{'text': _('Super-humanoid: Lila (EN)'), 'value': 'x5_EnUs_Lila_flow'},
{'text': _('Super-humanoid: Lingwanwan Pro'), 'value': 'x6_lingwanwan_pro'},
{'text': _('Super-humanoid: Yiyi Pro'), 'value': 'x6_yiyi_pro'},
{'text': _('Super-humanoid: Huifangnv Pro'), 'value': 'x6_huifangnv_pro'},
{'text': _('Super-humanoid: Lingxiaoying Pro'), 'value': 'x6_lingxiaoying_pro'},
{'text': _('Super-humanoid: Lingfeibo Pro'), 'value': 'x6_lingfeibo_pro'},
{'text': _('Super-humanoid: Lingyuyan Pro'), 'value': 'x6_lingyuyan_pro'},
],
relation_show_field_dict={"api_version": ["super_humanoid"]})

speed = forms.SliderField(
TooltipLabel(_('speaking speed'), _('Speech speed, optional value: [0-100], default is 50')),
required=True, default_value=50,
Expand All @@ -33,7 +70,21 @@ class XunFeiTTSModelGeneralParams(BaseForm):


class XunFeiTTSModelCredential(BaseForm, BaseModelCredential):
spark_api_url = forms.TextInputField('API URL', required=True, default_value='wss://tts-api.xfyun.cn/v2/tts')
api_version = forms.Radio('API Version', required=True, text_field='label', value_field='value',
option_list=[
{'label': '在线合成', 'value': 'online'},
{'label': '超拟人合成', 'value': 'super_humanoid'}
],
default_value='online',
provider='',
method='', )

spark_api_url = forms.TextInputField('API URL', required=True,
default_value='wss://tts-api.xfyun.cn/v2/tts',
relation_show_field_dict={"api_version": ["online"]})
spark_api_url_super = forms.TextInputField('API URL', required=True,
default_value='wss://cbm01.cn-huabei-1.xf-yun.com/v1/private/mcd9m97e6',
relation_show_field_dict={"api_version": ["super_humanoid"]})
spark_app_id = forms.TextInputField('APP ID', required=True)
spark_api_key = forms.PasswordInputField("API Key", required=True)
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
Expand All @@ -45,7 +96,13 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
raise AppApiException(ValidCode.valid_error.value,
gettext('{model_type} Model type is not supported').format(model_type=model_type))

for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
api_version = model_credential.get('api_version', 'online')
if api_version == 'super_humanoid':
required_keys = ['spark_api_url_super', 'spark_app_id', 'spark_api_key', 'spark_api_secret']
else:
required_keys = ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']

for key in required_keys:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
Expand All @@ -55,7 +112,6 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:
Expand All @@ -71,4 +127,4 @@ def encryption_dict(self, model: Dict[str, object]):
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}

def get_model_params_setting_form(self, model_name):
return XunFeiTTSModelGeneralParams()
return XunFeiTTSModelGeneralParams()
42 changes: 42 additions & 0 deletions apps/models_provider/impl/xf_model_provider/model/default_tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# coding=utf-8
"""
@project: MaxKB
@Author:
@file: default_tts.py
@date:2025/12/9
@desc: 讯飞 TTS 工厂类,根据 api_version 路由到具体实现
"""
from typing import Dict

from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_tts import BaseTextToSpeech


class XFSparkDefaultTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
"""讯飞 TTS 工厂类,根据 api_version 参数路由到具体实现"""

def check_auth(self):
pass

def text_to_speech(self, text):
pass

@staticmethod
def is_cache_model():
return False

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
from models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
from models_provider.impl.xf_model_provider.model.super_humanoid_tts import XFSparkSuperHumanoidTextToSpeech

api_version = model_credential.get('api_version', 'online')

if api_version == 'super_humanoid':
return XFSparkSuperHumanoidTextToSpeech.new_instance(
model_type, model_name, model_credential, **model_kwargs
)
else:
return XFSparkTextToSpeech.new_instance(
model_type, model_name, model_credential, **model_kwargs
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# -*- coding:utf-8 -*-
#
# author: iflytek
#
# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看)
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
import asyncio
import base64
import hashlib
import hmac
import json
import ssl
from datetime import datetime, UTC
from typing import Dict
from urllib.parse import urlencode, urlparse

import websockets
from django.utils.translation import gettext as _

from common.utils.common import _remove_empty_lines
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_tts import BaseTextToSpeech

ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE


class XFSparkSuperHumanoidTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
"""讯飞超拟人语音合成 (Super Humanoid TTS)"""
spark_app_id: str
spark_api_key: str
spark_api_secret: str
spark_api_url: str
params: dict

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.spark_api_url = kwargs.get('spark_api_url')
self.spark_app_id = kwargs.get('spark_app_id')
self.spark_api_key = kwargs.get('spark_api_key')
self.spark_api_secret = kwargs.get('spark_api_secret')
self.params = kwargs.get('params') or {}

@staticmethod
def is_cache_model():
return False

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
spark_api_url = model_credential.get('spark_api_url_super')
vcn = model_kwargs.get('vcn_super', 'x5_lingxiaoxuan_flow')

params = {'vcn': vcn}
for k, v in model_kwargs.items():
if k not in ['model_id', 'use_local', 'streaming', 'vcn_online', 'vcn_super', 'api_version']:
params[k] = v

return XFSparkSuperHumanoidTextToSpeech(
spark_app_id=model_credential.get('spark_app_id'),
spark_api_key=model_credential.get('spark_api_key'),
spark_api_secret=model_credential.get('spark_api_secret'),
spark_api_url=spark_api_url,
params=params
)

def create_url(self):
url = self.spark_api_url
host = urlparse(url).hostname

gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
date = datetime.now(UTC).strftime(gmt_format)

signature_origin = f"host: {host}\n"
signature_origin += f"date: {date}\n"
signature_origin += f"GET {urlparse(url).path} HTTP/1.1"

signature_sha = hmac.new(
self.spark_api_secret.encode('utf-8'),
signature_origin.encode('utf-8'),
digestmod=hashlib.sha256
).digest()

signature_sha = base64.b64encode(signature_sha).decode('utf-8')

authorization_origin = \
f'api_key="{self.spark_api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"'

authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode('utf-8')

v = {
"authorization": authorization,
"date": date,
"host": host
}

url = url + '?' + urlencode(v)
return url

def check_auth(self):
self.text_to_speech(_('Hello'))

def text_to_speech(self, text):
text = _remove_empty_lines(text)

async def handle():
try:
async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
await self.send(ws, text)
return await self.handle_message(ws)
except websockets.exceptions.InvalidStatus as e:
if e.response.status_code == 401:
raise Exception(
_("Authentication failed (HTTP 401). Please check: "
"1) API URL is correct for TTS service; "
"2) APP ID, API Key, and API Secret are correct; "
"3) Your iFlytek account has TTS service enabled.")
)
else:
raise Exception(f"WebSocket connection failed: HTTP {e.response.status_code}")
except Exception as e:
if "Authentication failed" in str(e):
raise
raise Exception(f"iFlytek TTS service error: {str(e)}")

return asyncio.run(handle())

@staticmethod
async def handle_message(ws):
audio_bytes: bytes = b''
while True:
res = await ws.recv()
message = json.loads(res)

if "header" in message and "code" in message["header"]:
code = message["header"]["code"]
sid = message["header"].get("sid", "unknown")

if code != 0:
errMsg = message["header"].get("message", "Unknown error")
raise Exception(f"sid: {sid} call error: {errMsg} code is: {code}")

if "payload" in message and "audio" in message["payload"]:
audio = base64.b64decode(message["payload"]["audio"]["audio"])
audio_bytes += audio

if message["payload"]["audio"].get("status") == 2:
break
else:
raise Exception(
f"Unexpected response from iFlytek API. Response: {json.dumps(message, ensure_ascii=False)}"
)

return audio_bytes

async def send(self, ws, text):
vcn_value = self.params.get("vcn", "x5_lingxiaoxuan_flow")

# 确保 vcn 值符合超拟人格式
if not vcn_value or not (str(vcn_value).startswith('x5_') or str(vcn_value).startswith('x6_')):
vcn_value = 'x5_lingxiaoxuan_flow'

audio_params = {
"encoding": self.params.get("encoding", "lame"),
"sample_rate": self.params.get("sample_rate", 24000),
"channels": self.params.get("channels", 1),
"bit_depth": self.params.get("bit_depth", 16),
"frame_size": self.params.get("frame_size", 0)
}

tts_params = {
"vcn": vcn_value,
"audio": audio_params,
"volume": self.params.get("volume", 50),
"speed": self.params.get("speed", 50),
"pitch": self.params.get("pitch", 50)
}

encoded_text = base64.b64encode(text.encode('utf-8')).decode('utf-8')
payload_text_obj = {
"encoding": "utf8",
"compress": "raw",
"format": "plain",
"status": 2,
"seq": 0,
"text": encoded_text
}

d = {
"header": {"app_id": self.spark_app_id, "status": 2},
"parameter": {"tts": tts_params},
"payload": {"text": payload_text_obj}
}

await ws.send(json.dumps(d))
Loading