Skip to content

Commit a5ef985

Browse files
committed
feat: STT model params
1 parent 94823b2 commit a5ef985

File tree

20 files changed

+193
-50
lines changed

20 files changed

+193
-50
lines changed

apps/application/flow/step_node/speech_to_text_step_node/i_speech_to_text_node.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ class SpeechToTextNodeSerializer(serializers.Serializer):
1616

1717
audio_list = serializers.ListField(required=True,
1818
label=_("The audio file cannot be empty"))
19+
model_params_setting = serializers.DictField(required=False,
20+
label=_("Model parameter settings"))
1921

2022

2123
class ISpeechToTextNode(INode):
@@ -35,6 +37,6 @@ def _run(self):
3537
return self.execute(audio=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
3638

3739
def execute(self, stt_model_id, chat_id,
            audio, model_params_setting=None,
            **kwargs) -> NodeResult:
    """Interface stub for the speech-to-text step; does no work here.

    ``_run`` calls this with ``audio`` plus the node-parameter and
    flow-parameter serializer data unpacked as keyword arguments.

    :param stt_model_id: identifier of the speech-to-text model.
    :param chat_id: identifier of the current chat.
    :param audio: audio input prepared by ``_run``.
    :param model_params_setting: optional dict of model parameter settings
        (the serializer field is not required, so it may be absent/None).
    :param kwargs: any remaining serializer fields passed through by ``_run``.
    :return: nothing in this stub (implicitly ``None``).
    """
    pass

apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ def save_context(self, details, workflow_manage):
2020
if self.node_params.get('is_result', False):
2121
self.answer_text = details.get('answer')
2222

23-
def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult:
23+
def execute(self, stt_model_id, chat_id, audio, model_params_setting=None, **kwargs) -> NodeResult:
2424
workspace_id = self.workflow_manage.get_body().get('workspace_id')
25-
stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id)
25+
stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id, **model_params_setting)
2626
audio_list = audio
2727
self.context['audio_list'] = audio
2828

apps/application/serializers/application.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ def speech_to_text(self, instance, debug=True, with_valid=True):
965965
application = QuerySet(ApplicationVersion).filter(application_id=application_id).order_by(
966966
'-create_time').first()
967967
if application.stt_model_enable:
968-
model = get_model_instance_by_model_workspace_id(application.stt_model_id, application.workspace_id)
968+
model = get_model_instance_by_model_workspace_id(application.stt_model_id, application.workspace_id, **application.stt_model_params_setting)
969969
text = model.speech_to_text(instance.get('file'))
970970
return text
971971

apps/locales/en_US/LC_MESSAGES/django.po

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8718,4 +8718,13 @@ msgid "Failed to obtain the image"
87188718
msgstr ""
87198719

87208720
msgid "Update auth setting"
8721+
msgstr ""
8722+
8723+
msgid "If not passed, the default value is streaming_asr_demo"
8724+
msgstr ""
8725+
8726+
msgid "If not passed, the default value is 16000"
8727+
msgstr ""
8728+
8729+
msgid "Sample Rate"
87218730
msgstr ""

apps/locales/zh_CN/LC_MESSAGES/django.po

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8844,4 +8844,13 @@ msgid "Failed to obtain the image"
88448844
msgstr "获取图片失败"
88458845

88468846
msgid "Update auth setting"
8847-
msgstr "更新认证设置"
8847+
msgstr "更新认证设置"
8848+
8849+
msgid "If not passed, the default value is streaming_asr_demo"
8850+
msgstr "如果未传入,则默认值为 streaming_asr_demo"
8851+
8852+
msgid "If not passed, the default value is 16000"
8853+
msgstr "如果未传入,则默认值为 16000"
8854+
8855+
msgid "Sample Rate"
8856+
msgstr "采样率"

apps/locales/zh_Hant/LC_MESSAGES/django.po

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8844,4 +8844,13 @@ msgid "Failed to obtain the image"
88448844
msgstr "獲取圖片失敗"
88458845

88468846
msgid "Update auth setting"
8847-
msgstr "更新認證設置"
8847+
msgstr "更新認證設置"
8848+
8849+
msgid "If not passed, the default value is streaming_asr_demo"
8850+
msgstr "如果未傳入,則預設值為 streaming_asr_demo"
8851+
8852+
msgid "If not passed, the default value is 16000"
8853+
msgstr "如果未傳入,則預設值為 16000"
8854+
8855+
msgid "Sample Rate"
8856+
msgstr "採樣率"

apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,21 @@
44
from typing import Dict, Any
55

66
from django.utils.translation import gettext as _
7+
8+
from common import forms
79
from common.exception.app_exception import AppApiException
8-
from common.forms import BaseForm, PasswordInputField
10+
from common.forms import BaseForm, PasswordInputField, TooltipLabel
911
from models_provider.base_model_provider import BaseModelCredential, ValidCode
1012

1113

14+
class AliyunBaiLianSTTModelParams(BaseForm):
    """Model-parameter form for the Aliyun BaiLian speech-to-text model.

    Declares one parameter, ``sample_rate``, as a required slider field.
    """

    # Slider bounds are 0..20000 with a step of 4000; the preselected value
    # is 16000, matching the tooltip text.
    # NOTE(review): _min=0 allows a sample rate of 0 to be selected — confirm
    # the lower bound against the rates the recognizer actually accepts.
    sample_rate = forms.SliderField(
        TooltipLabel(_('Sample Rate'), _('If not passed, the default value is 16000')),
        required=True,
        default_value=16000,
        _step=4000, _min=0, _max=20000, precision=0,
    )
21+
1222
class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
1323
"""
1424
Credential class for the Aliyun BaiLian STT (Speech-to-Text) model.
@@ -55,7 +65,7 @@ def is_valid(
5565
return False
5666

5767
try:
58-
model = provider.get_model(model_type, model_name, model_credential)
68+
model = provider.get_model(model_type, model_name, model_credential,**model_params)
5969
model.check_auth()
6070
except Exception as e:
6171
traceback.print_exc()
@@ -89,4 +99,4 @@ def get_model_params_setting_form(self, model_name: str):
8999
:param model_name: Name of the model.
90100
:return: Parameter setting form for the STT model (an AliyunBaiLianSTTModelParams instance).
91101
"""
92-
pass
102+
return AliyunBaiLianSTTModelParams()

apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@
1313
class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText):
1414
api_key: str
1515
model: str
16+
params: dict
1617

1718
def __init__(self, **kwargs):
    """Initialise the model from keyword arguments.

    :param kwargs: forwarded unchanged to the base-class initialiser; the
        keys ``api_key``, ``model`` and ``params`` are also read here.
    """
    super().__init__(**kwargs)
    self.api_key = kwargs.get('api_key')
    self.model = kwargs.get('model')
    # Default to an empty dict when 'params' is missing or None:
    # speech_to_text unpacks self.params with ``**``, and unpacking None
    # raises TypeError.
    self.params = kwargs.get('params') or {}
23+
2124

2225
@staticmethod
2326
def is_cache_model():
@@ -33,6 +36,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
3336
return AliyunBaiLianSpeechToText(
3437
model=model_name,
3538
api_key=model_credential.get('api_key'),
39+
params=model_kwargs,
3640
**optional_params,
3741
)
3842

@@ -43,10 +47,17 @@ def check_auth(self):
4347

4448
def speech_to_text(self, audio_file):
4549
dashscope.api_key = self.api_key
46-
recognition = Recognition(model=self.model,
47-
format='mp3',
48-
sample_rate=16000,
49-
callback=None)
50+
recognition_params = {
51+
'model': self.model,
52+
'format': 'mp3',
53+
'sample_rate': 16000,
54+
'callback': None,
55+
**self.params
56+
}
57+
print(recognition_params)
58+
recognition = Recognition(**recognition_params)
59+
60+
5061
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
5162
# 将上传的文件保存到临时文件中
5263
temp_file.write(audio_file.read())

apps/models_provider/impl/azure_model_provider/credential/stt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
2929
else:
3030
return False
3131
try:
32-
model = provider.get_model(model_type, model_name, model_credential)
32+
model = provider.get_model(model_type, model_name, model_credential, **model_params)
3333
model.check_auth()
3434
except Exception as e:
3535
traceback.print_exc()

apps/models_provider/impl/azure_model_provider/model/stt.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ class AzureOpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
1818
api_key: str
1919
api_version: str
2020
model: str
21+
params: dict
2122

2223
def __init__(self, **kwargs):
    """Initialise the model from keyword arguments.

    :param kwargs: forwarded unchanged to the base-class initialiser; the
        keys ``api_key``, ``api_base``, ``api_version`` and ``params`` are
        also read here.
    """
    super().__init__(**kwargs)
    self.api_key = kwargs.get('api_key')
    self.api_base = kwargs.get('api_base')
    self.api_version = kwargs.get('api_version')
    # NOTE(review): self.model is read by speech_to_text but is not assigned
    # here — presumably populated by the base-class initialiser; confirm.
    # Default to an empty dict when 'params' is missing or None:
    # speech_to_text calls self.params.items(), which fails on None.
    self.params = kwargs.get('params') or {}
2729

2830
@staticmethod
2931
def is_cache_model():
@@ -41,6 +43,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
4143
api_base=model_credential.get('api_base'),
4244
api_key=model_credential.get('api_key'),
4345
api_version=model_credential.get('api_version'),
46+
params=model_kwargs,
4447
**optional_params,
4548
)
4649

@@ -62,5 +65,13 @@ def speech_to_text(self, audio_file):
6265
audio_data = audio_file.read()
6366
buffer = io.BytesIO(audio_data)
6467
buffer.name = "file.mp3" # this is the important line
65-
res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)
68+
69+
filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}}
70+
transcription_params = {
71+
'model': self.model,
72+
'file': buffer,
73+
'language': 'zh'
74+
}
75+
76+
res = client.audio.transcriptions.create(**transcription_params, extra_body=filter_params)
6677
return res.text

0 commit comments

Comments
 (0)