diff --git a/apps/models_provider/impl/xf_model_provider/credential/tts.py b/apps/models_provider/impl/xf_model_provider/credential/tts.py index 121c7be919c..d8700beebbe 100644 --- a/apps/models_provider/impl/xf_model_provider/credential/tts.py +++ b/apps/models_provider/impl/xf_model_provider/credential/tts.py @@ -7,12 +7,21 @@ from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode -from common.utils.logger import maxkb_logger + class XunFeiTTSModelGeneralParams(BaseForm): - vcn = forms.SingleSelect( - TooltipLabel(_('Speaker'), - _('Speaker, optional value: Please go to the console to add a trial or purchase speaker. After adding, the speaker parameter value will be displayed.')), + api_version = forms.Radio('API Version', required=False, text_field='label', value_field='value', + option_list=[ + {'label': '在线语音', 'value': 'online'}, + {'label': '超拟人语音', 'value': 'super_humanoid'} + ], + default_value='online', + provider='', + method='', + props_info={'item_style': {'display': 'none'}}) + + vcn_online = forms.SingleSelect( + TooltipLabel(_('Speaker'), _('Speaker selection for standard TTS service')), required=True, default_value='xiaoyan', text_field='value', value_field='value', @@ -22,7 +31,35 @@ class XunFeiTTSModelGeneralParams(BaseForm): {'text': _('iFlytek Xiaoping'), 'value': 'aisxping'}, {'text': _('iFlytek Xiaojing'), 'value': 'aisjinger'}, {'text': _('iFlytek Xuxiaobao'), 'value': 'aisbabyxu'}, - ]) + ], + relation_show_field_dict={"api_version": ["online"]}) + + vcn_super = forms.SingleSelect( + TooltipLabel(_('Speaker'), _('Speaker selection for super-humanoid TTS service')), + required=True, default_value='x5_lingxiaoxuan_flow', + text_field='value', + value_field='value', + option_list=[ + {'text': _('Super-humanoid: Lingxiaoxuan Flow'), 'value': 'x5_lingxiaoxuan_flow'}, + {'text': _('Super-humanoid: Lingyuyan Flow'), 'value': 'x5_lingyuyan_flow'}, + {'text': _('Super-humanoid: Lingfeiyi Flow'), 'value': 'x5_lingfeiyi_flow'}, + {'text': _('Super-humanoid: Lingxiaoyue Flow'), 'value': 'x5_lingxiaoyue_flow'}, + {'text': _('Super-humanoid: Sun Dasheng Flow'), 'value': 'x5_sundasheng_flow'}, + {'text': _('Super-humanoid: Lingyuzhao Flow'), 'value': 'x5_lingyuzhao_flow'}, + {'text': _('Super-humanoid: Lingxiaotang Flow'), 'value': 'x5_lingxiaotang_flow'}, + {'text': _('Super-humanoid: Lingxiaorong Flow'), 'value': 'x5_lingxiaorong_flow'}, + {'text': _('Super-humanoid: Xinyun Flow'), 'value': 'x5_xinyun_flow'}, + {'text': _('Super-humanoid: Grant (EN)'), 'value': 'x5_EnUs_Grant_flow'}, + {'text': _('Super-humanoid: Lila (EN)'), 'value': 'x5_EnUs_Lila_flow'}, + {'text': _('Super-humanoid: Lingwanwan Pro'), 'value': 'x6_lingwanwan_pro'}, + {'text': _('Super-humanoid: Yiyi Pro'), 'value': 'x6_yiyi_pro'}, + {'text': _('Super-humanoid: Huifangnv Pro'), 'value': 'x6_huifangnv_pro'}, + {'text': _('Super-humanoid: Lingxiaoying Pro'), 'value': 'x6_lingxiaoying_pro'}, + {'text': _('Super-humanoid: Lingfeibo Pro'), 'value': 'x6_lingfeibo_pro'}, + {'text': _('Super-humanoid: Lingyuyan Pro'), 'value': 'x6_lingyuyan_pro'}, + ], + relation_show_field_dict={"api_version": ["super_humanoid"]}) + speed = forms.SliderField( TooltipLabel(_('speaking speed'), _('Speech speed, optional value: [0-100], default is 50')), required=True, default_value=50, @@ -33,7 +70,21 @@ class XunFeiTTSModelGeneralParams(BaseForm): class XunFeiTTSModelCredential(BaseForm, BaseModelCredential): - spark_api_url = forms.TextInputField('API URL', required=True, default_value='wss://tts-api.xfyun.cn/v2/tts') + api_version = forms.Radio('API Version', required=True, text_field='label', value_field='value', + option_list=[ + {'label': '在线合成', 'value': 'online'}, + {'label': '超拟人合成', 'value': 'super_humanoid'} + ], + default_value='online', + provider='', + method='', ) + + spark_api_url = forms.TextInputField('API URL', required=True, + default_value='wss://tts-api.xfyun.cn/v2/tts', + relation_show_field_dict={"api_version": ["online"]}) + spark_api_url_super = forms.TextInputField('API URL', required=True, + default_value='wss://cbm01.cn-huabei-1.xf-yun.com/v1/private/mcd9m97e6', + relation_show_field_dict={"api_version": ["super_humanoid"]}) spark_app_id = forms.TextInputField('APP ID', required=True) spark_api_key = forms.PasswordInputField("API Key", required=True) spark_api_secret = forms.PasswordInputField('API Secret', required=True) @@ -45,7 +96,13 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje raise AppApiException(ValidCode.valid_error.value, gettext('{model_type} Model type is not supported').format(model_type=model_type)) - for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']: + api_version = model_credential.get('api_version', 'online') + if api_version == 'super_humanoid': + required_keys = ['spark_api_url_super', 'spark_app_id', 'spark_api_key', 'spark_api_secret'] + else: + required_keys = ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret'] + + for key in required_keys: if key not in model_credential: if raise_exception: raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key)) @@ -55,7 +112,6 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: - maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/xf_model_provider/model/tts.py b/apps/models_provider/impl/xf_model_provider/model/tts.py index 8a6132b8f44..16b5c485a5e 100644 --- a/apps/models_provider/impl/xf_model_provider/model/tts.py +++ b/apps/models_provider/impl/xf_model_provider/model/tts.py @@ -4,13 +4,12 @@ # # 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看) # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +# -*- coding:utf-8 -*- import asyncio import base64 -import datetime import hashlib import hmac import json -import logging import ssl from datetime import datetime, UTC from typing import Dict @@ -23,10 +22,6 @@ from models_provider.base_model_provider import MaxKBBaseModel from models_provider.impl.base_tts import BaseTextToSpeech -STATUS_FIRST_FRAME = 0 # 第一帧的标识 -STATUS_CONTINUE_FRAME = 1 # 中间帧标识 -STATUS_LAST_FRAME = 2 # 最后一帧的标识 - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE @@ -45,7 +40,7 @@ def __init__(self, **kwargs): self.spark_app_id = kwargs.get('spark_app_id') self.spark_api_key = kwargs.get('spark_api_key') self.spark_api_secret = kwargs.get('spark_api_secret') - self.params = kwargs.get('params') + self.params = kwargs.get('params') or {} @staticmethod def is_cache_model(): @@ -53,100 +48,201 @@ def is_cache_model(): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - optional_params = {'params': {'vcn': 'xiaoyan', 'speed': 50}} - for key, value in model_kwargs.items(): - if key not in ['model_id', 'use_local', 'streaming']: - optional_params['params'][key] = value + api_version = model_credential.get('api_version', 'online') + + if api_version == 'super_humanoid': + spark_api_url = model_credential.get('spark_api_url_super') + vcn = model_kwargs.get('vcn_super', 'x5_lingxiaoxuan_flow') + else: + spark_api_url = model_credential.get('spark_api_url') + vcn = model_kwargs.get('vcn_online', 'xiaoyan') + + params = {'api_version': api_version, 'vcn': vcn} + for k, v in model_kwargs.items(): + if k not in ['model_id', 'use_local', 'streaming', 'vcn_online', 'vcn_super', 'api_version']: + params[k] = v + return XFSparkTextToSpeech( spark_app_id=model_credential.get('spark_app_id'), spark_api_key=model_credential.get('spark_api_key'), spark_api_secret=model_credential.get('spark_api_secret'), - spark_api_url=model_credential.get('spark_api_url'), - **optional_params + spark_api_url=spark_api_url, + params=params ) - # 生成url def create_url(self): url = self.spark_api_url host = urlparse(url).hostname - # 生成RFC1123格式的时间戳 + gmt_format = '%a, %d %b %Y %H:%M:%S GMT' date = datetime.now(UTC).strftime(gmt_format) - # 拼接字符串 - signature_origin = "host: " + host + "\n" - signature_origin += "date: " + date + "\n" - signature_origin += "GET " + "/v2/tts " + "HTTP/1.1" - # 进行hmac-sha256进行加密 - signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'), - digestmod=hashlib.sha256).digest() - signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8') - - authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % ( - self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha) - authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') - # 将请求的鉴权参数组合为字典 + signature_origin = f"host: {host}\n" + signature_origin += f"date: {date}\n" + signature_origin += f"GET {urlparse(url).path} HTTP/1.1" + + signature_sha = hmac.new( + self.spark_api_secret.encode('utf-8'), + signature_origin.encode('utf-8'), + digestmod=hashlib.sha256 + ).digest() + + signature_sha = base64.b64encode(signature_sha).decode('utf-8') + + authorization_origin = \ + f'api_key="{self.spark_api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"' + + authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode('utf-8') + v = { "authorization": authorization, "date": date, "host": host } - # 拼接鉴权参数,生成url + url = url + '?' + urlencode(v) - # print("date: ",date) - # print("v: ",v) - # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 - # print('websocket url :', url) return url def check_auth(self): self.text_to_speech(_('Hello')) def text_to_speech(self, text): - - # 使用小语种须使用以下方式,此处的unicode指的是 utf16小端的编码方式,即"UTF-16LE"” - # self.Data = {"status": 2, "text": str(base64.b64encode(self.Text.encode('utf-16')), "UTF8")} text = _remove_empty_lines(text) async def handle(): - async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws: - # 发送 full client request - await self.send(ws, text) - return await self.handle_message(ws) + try: + async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws: + await self.send(ws, text) + return await self.handle_message(ws) + except websockets.exceptions.InvalidStatus as e: + if e.response.status_code == 401: + raise Exception( + _("Authentication failed (HTTP 401). Please check: " + "1) API URL is correct for TTS service; " + "2) APP ID, API Key, and API Secret are correct; " + "3) Your iFlytek account has TTS service enabled.") + ) + else: + raise Exception(f"WebSocket connection failed: HTTP {e.response.status_code}") + except Exception as e: + if "Authentication failed" in str(e): + raise + raise Exception(f"iFlytek TTS service error: {str(e)}") return asyncio.run(handle()) - def is_cache_model(self): - return False - @staticmethod async def handle_message(ws): audio_bytes: bytes = b'' while True: res = await ws.recv() message = json.loads(res) - # print(message) - code = message["code"] - sid = message["sid"] - if code != 0: - errMsg = message["message"] - raise Exception(f"sid: {sid} call error: {errMsg} code is: {code}") - else: - audio = message["data"]["audio"] - audio = base64.b64decode(audio) + if "header" in message and "code" in message["header"]: + code = message["header"]["code"] + sid = message["header"].get("sid", "unknown") + + if code != 0: + errMsg = message["header"].get("message", "Unknown error") + raise Exception(f"sid: {sid} call error: {errMsg} code is: {code}") + + if "payload" in message and "audio" in message["payload"]: + audio = base64.b64decode(message["payload"]["audio"]["audio"]) + audio_bytes += audio + + if message["payload"]["audio"].get("status") == 2: + break + + elif "code" in message: + code = message["code"] + sid = message.get("sid", "unknown") + + if code != 0: + errMsg = message.get("message", "Unknown error") + raise Exception(f"sid: {sid} call error: {errMsg} code is: {code}") + + audio = base64.b64decode(message["data"]["audio"]) audio_bytes += audio - # 退出 - if message["data"]["status"] == 2: - break + + if message["data"]["status"] == 2: + break + + else: + raise Exception( + f"Unexpected response from iFlytek API. Response: {json.dumps(message, ensure_ascii=False)}" + ) + return audio_bytes async def send(self, ws, text): - business = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "tte": "utf8"} - d = { - "common": {"app_id": self.spark_app_id}, - "business": business | self.params, - "data": {"status": 2, "text": str(base64.b64encode(text.encode('utf-8')), "UTF8")}, - } - d = json.dumps(d) - await ws.send(d) + api_version = None + if self.params and 'api_version' in self.params: + api_version = self.params.get('api_version') + else: + api_version = 'super_humanoid' if '/v1/private/' in (self.spark_api_url or '') else 'online' + + vcn_value = self.params.get("vcn", None) + + if api_version == 'super_humanoid': + audio_params = { + "encoding": self.params.get("encoding", "lame"), + "sample_rate": self.params.get("sample_rate", 24000), + "channels": self.params.get("channels", 1), + "bit_depth": self.params.get("bit_depth", 16), + "frame_size": self.params.get("frame_size", 0) + } + + if not vcn_value or not (str(vcn_value).startswith('x5_') or str(vcn_value).startswith('x6_')): + vcn_value = 'x5_lingxiaoxuan_flow' + + tts_params = { + "vcn": vcn_value, + "audio": audio_params, + "volume": self.params.get("volume", 50), + "speed": self.params.get("speed", 50), + "pitch": self.params.get("pitch", 50) + } + + encoded_text = base64.b64encode(text.encode('utf-8')).decode('utf-8') + payload_text_obj = { + "encoding": "utf8", + "compress": "raw", + "format": "plain", + "status": 2, + "seq": 0, + "text": encoded_text + } + + d = { + "header": {"app_id": self.spark_app_id, "status": 2}, + "parameter": {"tts": tts_params}, + "payload": {"text": payload_text_obj} + } + + await ws.send(json.dumps(d)) + + else: + encoded_text = base64.b64encode(text.encode('utf-8')).decode('utf-8') + business = { + "aue": "lame", + "sfl": 1, + "auf": "audio/L16;rate=16000", + "tte": "utf8" + } + + if "speed" in self.params: + business["speed"] = self.params.get("speed") + if "volume" in self.params: + business["volume"] = self.params.get("volume") + if "pitch" in self.params: + business["pitch"] = self.params.get("pitch") + if vcn_value: + business["vcn"] = vcn_value + + d = { + "common": {"app_id": self.spark_app_id}, + "business": business, + "data": {"status": 2, "text": encoded_text}, + } + + await ws.send(json.dumps(d)) diff --git a/apps/models_provider/serializers/model_serializer.py b/apps/models_provider/serializers/model_serializer.py index adaafb2a731..2cc124ef9c4 100644 --- a/apps/models_provider/serializers/model_serializer.py +++ b/apps/models_provider/serializers/model_serializer.py @@ -471,7 +471,22 @@ def get_model_params(self, with_valid=True): self.is_valid(raise_exception=True) model_id = self.data.get('id') model = QuerySet(Model).filter(id=model_id).first() - return model.model_params_form + model_params_form = model.model_params_form + + # 从 credential 中获取 api_version 值,并注入到 params 表单中 + # 这样前端隐藏的 api_version 字段会使用正确的值 + try: + credential = json.loads(rsa_long_decrypt(model.credential)) + api_version = credential.get('api_version') + if api_version: + for param in model_params_form: + if param.get('field') == 'api_version': + param['default_value'] = api_version + break + except Exception: + pass # 如果解密失败,使用原始的 params 表单 + + return model_params_form def save_model_params_form(self, model_params_form, with_valid=True): if with_valid: