Skip to content

Commit 4786970

Browse files
committed
feat: Support iFLYTEK large model for Chinese-English speech recognition
1 parent f9f96fd commit 4786970

File tree

3 files changed

+258
-2
lines changed

3 files changed

+258
-2
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# coding=utf-8
2+
import traceback
3+
from typing import Dict
4+
5+
from django.utils.translation import gettext as _
6+
7+
from common import forms
8+
from common.exception.app_exception import AppApiException
9+
from common.forms import BaseForm
10+
from models_provider.base_model_provider import BaseModelCredential, ValidCode
11+
12+
13+
14+
class ZhEnXunFeiSTTModelCredential(BaseForm, BaseModelCredential):
15+
spark_api_url = forms.TextInputField('API URL', required=True, default_value='wss://iat.xf-yun.com/v1')
16+
spark_app_id = forms.TextInputField('APP ID', required=True)
17+
spark_api_key = forms.PasswordInputField("API Key", required=True)
18+
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
19+
20+
def is_valid(self,
21+
model_type: str,
22+
model_name,
23+
model_credential: Dict[str, object],
24+
model_params, provider,
25+
raise_exception=False):
26+
model_type_list = provider.get_model_type_list()
27+
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
28+
raise AppApiException(ValidCode.valid_error.value,
29+
_('{model_type} Model type is not supported').format(model_type=model_type))
30+
31+
for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
32+
if key not in model_credential:
33+
if raise_exception:
34+
raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
35+
else:
36+
return False
37+
try:
38+
model = provider.get_model(model_type, model_name, model_credential)
39+
model.check_auth()
40+
except Exception as e:
41+
traceback.print_exc()
42+
if isinstance(e, AppApiException):
43+
raise e
44+
if raise_exception:
45+
raise AppApiException(ValidCode.valid_error.value,
46+
_('Verification failed, please check whether the parameters are correct: {error}').format(
47+
error=str(e)))
48+
else:
49+
return False
50+
return True
51+
52+
def encryption_dict(self, model: Dict[str, object]):
53+
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
54+
55+
def get_model_params_setting_form(self, model_name):
56+
pass
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
import asyncio
2+
import json
3+
import base64
4+
import hmac
5+
import hashlib
6+
import ssl
7+
import traceback
8+
from typing import Dict
9+
from urllib.parse import urlencode
10+
from datetime import datetime, timezone, UTC
11+
import websockets
12+
import os
13+
14+
from future.backports.urllib.parse import urlparse
15+
16+
from common.utils.logger import maxkb_logger
17+
from models_provider.base_model_provider import MaxKBBaseModel
18+
from models_provider.impl.base_stt import BaseSpeechToText
19+
20+
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
21+
ssl_context.check_hostname = False
22+
ssl_context.verify_mode = ssl.CERT_NONE
23+
24+
25+
class XFZhEnSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
26+
spark_app_id: str
27+
spark_api_key: str
28+
spark_api_secret: str
29+
spark_api_url: str
30+
31+
def __init__(self, **kwargs):
32+
super().__init__(**kwargs)
33+
self.spark_api_url = kwargs.get('spark_api_url')
34+
self.spark_app_id = kwargs.get('spark_app_id')
35+
self.spark_api_key = kwargs.get('spark_api_key')
36+
self.spark_api_secret = kwargs.get('spark_api_secret')
37+
38+
@staticmethod
39+
def is_cache_model():
40+
return False
41+
42+
@staticmethod
43+
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
44+
optional_params = {}
45+
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
46+
optional_params['max_tokens'] = model_kwargs['max_tokens']
47+
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
48+
optional_params['temperature'] = model_kwargs['temperature']
49+
return XFZhEnSparkSpeechToText(
50+
spark_app_id=model_credential.get('spark_app_id'),
51+
spark_api_key=model_credential.get('spark_api_key'),
52+
spark_api_secret=model_credential.get('spark_api_secret'),
53+
spark_api_url=model_credential.get('spark_api_url'),
54+
**optional_params
55+
)
56+
57+
# 生成url
58+
def create_url(self):
59+
url = self.spark_api_url
60+
host = urlparse(url).hostname
61+
62+
gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
63+
date = datetime.now(UTC).strftime(gmt_format)
64+
# 拼接字符串
65+
signature_origin = "host: " + host + "\n"
66+
signature_origin += "date: " + date + "\n"
67+
signature_origin += "GET " + "/v1 HTTP/1.1"
68+
# 进行hmac-sha256进行加密
69+
signature_sha = hmac.new(
70+
self.spark_api_secret.encode('utf-8'),
71+
signature_origin.encode('utf-8'),
72+
hashlib.sha256
73+
).digest()
74+
signature = base64.b64encode(signature_sha).decode(encoding='utf-8')
75+
76+
authorization_origin = (
77+
f'api_key="{self.spark_api_key}", algorithm="hmac-sha256", '
78+
f'headers="host date request-line", signature="{signature}"'
79+
)
80+
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
81+
82+
params = {
83+
'authorization': authorization,
84+
'date': date,
85+
'host': host
86+
}
87+
auth_url = url + '?' + urlencode(params)
88+
return auth_url
89+
90+
def check_auth(self):
91+
cwd = os.path.dirname(os.path.abspath(__file__))
92+
with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as f:
93+
self.speech_to_text(f)
94+
95+
def speech_to_text(self, audio_file_path):
96+
async def handle():
97+
async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
98+
# print("连接成功")
99+
# 发送音频数据
100+
await self.send_audio(ws, audio_file_path)
101+
# 接收识别结果
102+
return await self.handle_message(ws)
103+
try:
104+
return asyncio.run(handle())
105+
except Exception as err:
106+
maxkb_logger.error(f"语音识别错误: {str(err)}: {traceback.format_exc()}")
107+
return ""
108+
109+
async def send_audio(self, ws, audio_file):
110+
"""发送音频数据"""
111+
chunk_size = 4000
112+
seq = 1
113+
max_chunks = 10000
114+
while True:
115+
chunk = audio_file.read(chunk_size)
116+
if not chunk or seq > max_chunks:
117+
break
118+
119+
chunk_base64 = base64.b64encode(chunk).decode('utf-8')
120+
# 第一帧
121+
if seq == 1:
122+
frame = {
123+
"header": {"app_id": self.spark_app_id, "status": 0},
124+
"parameter": {
125+
"iat": {
126+
"domain": "slm", "language": "zh_cn", "accent": "mandarin",
127+
"eos": 10000, "vinfo": 1,
128+
"result": {"encoding": "utf8", "compress": "raw", "format": "json"}
129+
}
130+
},
131+
"payload": {
132+
"audio": {
133+
"encoding": "lame", "sample_rate": 16000, "channels": 1,
134+
"bit_depth": 16, "seq": seq, "status": 0, "audio": chunk_base64
135+
}
136+
}
137+
}
138+
# 中间帧
139+
else:
140+
frame = {
141+
"header": {"app_id": self.spark_app_id, "status": 1},
142+
"payload": {
143+
"audio": {
144+
"encoding": "lame", "sample_rate": 16000, "channels": 1,
145+
"bit_depth": 16, "seq": seq, "status": 1, "audio": chunk_base64
146+
}
147+
}
148+
}
149+
150+
await ws.send(json.dumps(frame))
151+
seq += 1
152+
153+
# 发送结束帧
154+
end_frame = {
155+
"header": {"app_id": self.spark_app_id, "status": 2},
156+
"payload": {
157+
"audio": {
158+
"encoding": "lame", "sample_rate": 16000, "channels": 1,
159+
"bit_depth": 16, "seq": seq, "status": 2, "audio": ""
160+
}
161+
}
162+
}
163+
await ws.send(json.dumps(end_frame))
164+
165+
166+
# 接受信息处理器
167+
async def handle_message(self, ws):
168+
result_text = ""
169+
while True:
170+
try:
171+
message = await asyncio.wait_for(ws.recv(), timeout=30.0)
172+
data = json.loads(message)
173+
174+
if data['header']['code'] != 0:
175+
raise Exception("")
176+
177+
if 'payload' in data and 'result' in data['payload']:
178+
result = data['payload']['result']
179+
text = result.get('text', '')
180+
if text:
181+
text_data = json.loads(base64.b64decode(text).decode('utf-8'))
182+
for ws_item in text_data.get('ws', []):
183+
for cw in ws_item.get('cw', []):
184+
for sw in cw.get('sw', []):
185+
result_text += sw['w']
186+
187+
if data['header'].get('status') == 2:
188+
break
189+
except asyncio.TimeoutError:
190+
break
191+
192+
return result_text

apps/models_provider/impl/xf_model_provider/xf_model_provider.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
1818
from models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
1919
from models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
20+
from models_provider.impl.xf_model_provider.credential.zh_en_stt import ZhEnXunFeiSTTModelCredential
2021
from models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
2122
from models_provider.impl.xf_model_provider.model.image import XFSparkImage
2223
from models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
@@ -25,18 +26,24 @@
2526
from maxkb.conf import PROJECT_DIR
2627
from django.utils.translation import gettext as _
2728

29+
from models_provider.impl.xf_model_provider.model.zh_en_stt import XFZhEnSparkSpeechToText
30+
2831
ssl._create_default_https_context = ssl.create_default_context()
2932

3033
xunfei_model_credential = XunFeiLLMModelCredential()
3134
stt_model_credential = XunFeiSTTModelCredential()
35+
zh_en_stt_credential = ZhEnXunFeiSTTModelCredential()
3236
image_model_credential = XunFeiImageModelCredential()
3337
tts_model_credential = XunFeiTTSModelCredential()
3438
embedding_model_credential = XFEmbeddingCredential()
3539
model_info_list = [
3640
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
3741
ModelInfo('generalv3', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
3842
ModelInfo('generalv2', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
39-
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
43+
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential,
44+
XFSparkSpeechToText),
45+
ModelInfo('slm', _('Chinese and English recognition'), ModelTypeConst.STT, zh_en_stt_credential,
46+
XFZhEnSparkSpeechToText),
4047
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
4148
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
4249
]
@@ -47,7 +54,8 @@
4754
.append_default_model_info(
4855
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM))
4956
.append_default_model_info(
50-
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
57+
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential,
58+
XFSparkSpeechToText),
5159
)
5260
.append_default_model_info(
5361
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech))

0 commit comments

Comments
 (0)