diff --git a/apps/application/flow/step_node/speech_to_text_step_node/i_speech_to_text_node.py b/apps/application/flow/step_node/speech_to_text_step_node/i_speech_to_text_node.py index 8577a1d5fe6..719e4201e88 100644 --- a/apps/application/flow/step_node/speech_to_text_step_node/i_speech_to_text_node.py +++ b/apps/application/flow/step_node/speech_to_text_step_node/i_speech_to_text_node.py @@ -16,6 +16,8 @@ class SpeechToTextNodeSerializer(serializers.Serializer): audio_list = serializers.ListField(required=True, label=_("The audio file cannot be empty")) + model_params_setting = serializers.DictField(required=False, + label=_("Model parameter settings")) class ISpeechToTextNode(INode): @@ -35,6 +37,6 @@ def _run(self): return self.execute(audio=res, **self.node_params_serializer.data, **self.flow_params_serializer.data) def execute(self, stt_model_id, chat_id, - audio, + audio, model_params_setting=None, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py b/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py index 7912873c98a..613599d0ad2 100644 --- a/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py +++ b/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py @@ -20,9 +20,9 @@ def save_context(self, details, workflow_manage): if self.node_params.get('is_result', False): self.answer_text = details.get('answer') - def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult: + def execute(self, stt_model_id, chat_id, audio, model_params_setting=None, **kwargs) -> NodeResult: workspace_id = self.workflow_manage.get_body().get('workspace_id') - stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id) + stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id, **model_params_setting) audio_list = audio self.context['audio_list'] = audio diff --git a/apps/application/serializers/application.py b/apps/application/serializers/application.py index 9fb16f72234..7560c47c867 100644 --- a/apps/application/serializers/application.py +++ b/apps/application/serializers/application.py @@ -965,7 +965,7 @@ def speech_to_text(self, instance, debug=True, with_valid=True): application = QuerySet(ApplicationVersion).filter(application_id=application_id).order_by( '-create_time').first() if application.stt_model_enable: - model = get_model_instance_by_model_workspace_id(application.stt_model_id, application.workspace_id) + model = get_model_instance_by_model_workspace_id(application.stt_model_id, application.workspace_id, **application.stt_model_params_setting) text = model.speech_to_text(instance.get('file')) return text diff --git a/apps/locales/en_US/LC_MESSAGES/django.po b/apps/locales/en_US/LC_MESSAGES/django.po index 48ec142a637..0db4c4b8de5 100644 --- a/apps/locales/en_US/LC_MESSAGES/django.po +++ b/apps/locales/en_US/LC_MESSAGES/django.po @@ -8718,4 +8718,13 @@ msgid "Failed to obtain the image" msgstr "" msgid "Update auth setting" +msgstr "" + +msgid "If not passed, the default value is streaming_asr_demo" +msgstr "" + +msgid "If not passed, the default value is 16000" +msgstr "" + +msgid "Sample Rate" msgstr "" \ No newline at end of file diff --git a/apps/locales/zh_CN/LC_MESSAGES/django.po b/apps/locales/zh_CN/LC_MESSAGES/django.po index 62885235e29..d6a4d06fec5 100644 --- a/apps/locales/zh_CN/LC_MESSAGES/django.po +++ b/apps/locales/zh_CN/LC_MESSAGES/django.po @@ -8844,4 +8844,13 @@ msgid "Failed to obtain the image" msgstr "获取图片失败" msgid "Update auth setting" -msgstr "更新认证设置" \ No newline at end of file +msgstr "更新认证设置" + +msgid "If not passed, the default value is streaming_asr_demo" +msgstr "如果未传入,则默认值为 streaming_asr_demo" + +msgid "If not passed, the default value is 16000" +msgstr "如果未传入,则默认值为 16000" + +msgid "Sample Rate" +msgstr "采样率" \ No newline at end of file diff --git a/apps/locales/zh_Hant/LC_MESSAGES/django.po b/apps/locales/zh_Hant/LC_MESSAGES/django.po index d3bdca50a25..9952bbc3f26 100644 --- a/apps/locales/zh_Hant/LC_MESSAGES/django.po +++ b/apps/locales/zh_Hant/LC_MESSAGES/django.po @@ -8844,4 +8844,13 @@ msgid "Failed to obtain the image" msgstr "獲取圖片失敗" msgid "Update auth setting" -msgstr "更新認證設置" \ No newline at end of file +msgstr "更新認證設置" + +msgid "If not passed, the default value is streaming_asr_demo" +msgstr "如果未傳入,則預設值為 streaming_asr_demo" + +msgid "If not passed, the default value is 16000" +msgstr "如果未傳入,則預設值為 16000" + +msgid "Sample Rate" +msgstr "採樣率" \ No newline at end of file diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py index a071f66a738..a6ee939127c 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py @@ -4,11 +4,21 @@ from typing import Dict, Any from django.utils.translation import gettext as _ + +from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm, PasswordInputField +from common.forms import BaseForm, PasswordInputField, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode +class AliyunBaiLianSTTModelParams(BaseForm): + sample_rate = forms.SliderField( + TooltipLabel(_('Sample Rate'), _('If not passed, the default value is 16000')), + required=True, + default_value=16000, + _step=4000, _min=0, _max=20000,precision=0 + ) + class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential): """ Credential class for the Aliyun BaiLian STT (Speech-to-Text) model. @@ -55,7 +65,7 @@ def is_valid( return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential,**model_params) model.check_auth() except Exception as e: traceback.print_exc() @@ -89,4 +99,4 @@ def get_model_params_setting_form(self, model_name: str): :param model_name: Name of the model. :return: Parameter setting form (not implemented). """ - pass + return AliyunBaiLianSTTModelParams() diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt.py index 7017caf79db..ece41f3dd08 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt.py @@ -13,11 +13,14 @@ class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText): api_key: str model: str + params: dict def __init__(self, **kwargs): super().__init__(**kwargs) self.api_key = kwargs.get('api_key') self.model = kwargs.get('model') + self.params = kwargs.get('params') + @staticmethod def is_cache_model(): @@ -33,6 +36,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** return AliyunBaiLianSpeechToText( model=model_name, api_key=model_credential.get('api_key'), + params=model_kwargs, **optional_params, ) @@ -43,10 +47,17 @@ def check_auth(self): def speech_to_text(self, audio_file): dashscope.api_key = self.api_key - recognition = Recognition(model=self.model, - format='mp3', - sample_rate=16000, - callback=None) + recognition_params = { + 'model': self.model, + 'format': 'mp3', + 'sample_rate': 16000, + 'callback': None, + **self.params + } + print(recognition_params) + recognition = Recognition(**recognition_params) + + with tempfile.NamedTemporaryFile(delete=False) as temp_file: # 将上传的文件保存到临时文件中 temp_file.write(audio_file.read()) diff --git a/apps/models_provider/impl/azure_model_provider/credential/stt.py b/apps/models_provider/impl/azure_model_provider/credential/stt.py index 7150783349d..cd115473f6e 100644 --- a/apps/models_provider/impl/azure_model_provider/credential/stt.py +++ b/apps/models_provider/impl/azure_model_provider/credential/stt.py @@ -29,7 +29,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: traceback.print_exc() diff --git a/apps/models_provider/impl/azure_model_provider/model/stt.py b/apps/models_provider/impl/azure_model_provider/model/stt.py index 53f82e72fa3..c6364f37328 100644 --- a/apps/models_provider/impl/azure_model_provider/model/stt.py +++ b/apps/models_provider/impl/azure_model_provider/model/stt.py @@ -18,12 +18,14 @@ class AzureOpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText): api_key: str api_version: str model: str + params: dict def __init__(self, **kwargs): super().__init__(**kwargs) self.api_key = kwargs.get('api_key') self.api_base = kwargs.get('api_base') self.api_version = kwargs.get('api_version') + self.params = kwargs.get('params') @staticmethod def is_cache_model(): @@ -41,6 +43,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** api_base=model_credential.get('api_base'), api_key=model_credential.get('api_key'), api_version=model_credential.get('api_version'), + params=model_kwargs, **optional_params, ) @@ -62,5 +65,13 @@ def speech_to_text(self, audio_file): audio_data = audio_file.read() buffer = io.BytesIO(audio_data) buffer.name = "file.mp3" # this is the important line - res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer) + + filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}} + transcription_params = { + 'model': self.model, + 'file': buffer, + 'language': 'zh' + } + + res = client.audio.transcriptions.create(**transcription_params, extra_body=filter_params) return res.text diff --git a/apps/models_provider/impl/openai_model_provider/credential/stt.py b/apps/models_provider/impl/openai_model_provider/credential/stt.py index 6a1dd847450..b70785bc6aa 100644 --- a/apps/models_provider/impl/openai_model_provider/credential/stt.py +++ b/apps/models_provider/impl/openai_model_provider/credential/stt.py @@ -6,10 +6,17 @@ from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode +class OpenAISTTModelParams(BaseForm): + language = forms.TextInputField( + TooltipLabel(_('language'), _('If not passed, the default value is zh')), + required=True, + default_value='zh', + ) + class OpenAISTTModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API URL', required=True) api_key = forms.PasswordInputField('API Key', required=True) @@ -28,7 +35,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: traceback.print_exc() @@ -46,4 +53,5 @@ def encryption_dict(self, model: Dict[str, object]): return {**model, 'api_key': super().encryption(model.get('api_key', ''))} def get_model_params_setting_form(self, model_name): - pass + + return OpenAISTTModelParams() diff --git a/apps/models_provider/impl/openai_model_provider/model/stt.py b/apps/models_provider/impl/openai_model_provider/model/stt.py index 6df1dff0ac5..32999855631 100644 --- a/apps/models_provider/impl/openai_model_provider/model/stt.py +++ b/apps/models_provider/impl/openai_model_provider/model/stt.py @@ -18,6 +18,7 @@ class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText): api_base: str api_key: str model: str + params: dict @staticmethod def is_cache_model(): @@ -27,6 +28,8 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.api_key = kwargs.get('api_key') self.api_base = kwargs.get('api_base') + self.model = kwargs.get('model') + self.params = kwargs.get('params') @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): @@ -39,6 +42,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model=model_name, api_base=model_credential.get('api_base'), api_key=model_credential.get('api_key'), + params = model_kwargs, **optional_params, ) @@ -58,6 +62,14 @@ def speech_to_text(self, audio_file): audio_data = audio_file.read() buffer = io.BytesIO(audio_data) buffer.name = "file.mp3" # this is the important line - res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer) + + filter_params = {k: v for k,v in self.params.items() if k not in {'model_id','use_local','streaming'}} + transcription_params = { + 'model': self.model, + 'file': buffer, + 'language': 'zh' + } + + res = client.audio.transcriptions.create(**transcription_params,extra_body=filter_params) return res.text diff --git a/apps/models_provider/impl/siliconCloud_model_provider/credential/stt.py b/apps/models_provider/impl/siliconCloud_model_provider/credential/stt.py index 6ce4e87912c..13e9cbe0ed5 100644 --- a/apps/models_provider/impl/siliconCloud_model_provider/credential/stt.py +++ b/apps/models_provider/impl/siliconCloud_model_provider/credential/stt.py @@ -28,7 +28,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential,**model_params) model.check_auth() except Exception as e: traceback.print_exc() diff --git a/apps/models_provider/impl/siliconCloud_model_provider/model/stt.py b/apps/models_provider/impl/siliconCloud_model_provider/model/stt.py index c946ed39c88..b5eb1012860 100644 --- a/apps/models_provider/impl/siliconCloud_model_provider/model/stt.py +++ b/apps/models_provider/impl/siliconCloud_model_provider/model/stt.py @@ -18,11 +18,13 @@ class SiliconCloudSpeechToText(MaxKBBaseModel, BaseSpeechToText): api_base: str api_key: str model: str + params: dict def __init__(self, **kwargs): super().__init__(**kwargs) self.api_key = kwargs.get('api_key') self.api_base = kwargs.get('api_base') + self.params = kwargs.get('params') @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): @@ -35,6 +37,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model=model_name, api_base=model_credential.get('api_base'), api_key=model_credential.get('api_key'), + params=model_kwargs, **optional_params, ) @@ -58,5 +61,13 @@ def speech_to_text(self, audio_file): audio_data = audio_file.read() buffer = io.BytesIO(audio_data) buffer.name = "file.mp3" # this is the important line - res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer) + + filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}} + transcription_params = { + 'model': self.model, + 'file': buffer, + 'language': 'zh' + } + + res = client.audio.transcriptions.create(**transcription_params,extra_body=filter_params) return res.text diff --git a/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py b/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py index f57c046e15c..922d934a8d8 100644 --- a/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py +++ b/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py @@ -53,11 +53,14 @@ def speech_to_text(self, audio_file): base_url=base_url ) + filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}} + transcription_params = { + 'model': self.model, + 'file': audio_file, + 'language': 'zh', + } result = client.audio.transcriptions.create( - file=audio_file, - model=self.model, - language=self.params.get('Language'), - response_format="json" + **transcription_params, extra_body=filter_params ) return result.text diff --git a/apps/models_provider/impl/volcanic_engine_model_provider/credential/stt.py b/apps/models_provider/impl/volcanic_engine_model_provider/credential/stt.py index f7e9ecc87e2..12c18325fa5 100644 --- a/apps/models_provider/impl/volcanic_engine_model_provider/credential/stt.py +++ b/apps/models_provider/impl/volcanic_engine_model_provider/credential/stt.py @@ -6,9 +6,17 @@ from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode +class VolcanicEngineSTTModelParams(BaseForm): + uid = forms.TextInputField( + TooltipLabel(_('User ID'),_('If not passed, the default value is streaming_asr_demo')), + required=True, + default_value='streaming_asr_demo' + ) + + class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential): volcanic_api_url = forms.TextInputField('API URL', required=True, @@ -31,7 +39,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: traceback.print_exc() @@ -49,4 +57,4 @@ def encryption_dict(self, model: Dict[str, object]): return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))} def get_model_params_setting_form(self, model_name): - pass + return VolcanicEngineSTTModelParams() diff --git a/apps/models_provider/impl/volcanic_engine_model_provider/model/stt.py b/apps/models_provider/impl/volcanic_engine_model_provider/model/stt.py index 5b17919521a..bc0e5128f49 100644 --- a/apps/models_provider/impl/volcanic_engine_model_provider/model/stt.py +++ b/apps/models_provider/impl/volcanic_engine_model_provider/model/stt.py @@ -192,6 +192,7 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText): volcanic_cluster: str volcanic_api_url: str volcanic_token: str + params: dict def __init__(self, **kwargs): super().__init__(**kwargs) @@ -199,6 +200,7 @@ def __init__(self, **kwargs): self.volcanic_token = kwargs.get('volcanic_token') self.volcanic_app_id = kwargs.get('volcanic_app_id') self.volcanic_cluster = kwargs.get('volcanic_cluster') + self.params = kwargs.get('params') @staticmethod def is_cache_model(): @@ -216,10 +218,14 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** volcanic_token=model_credential.get('volcanic_token'), volcanic_app_id=model_credential.get('volcanic_app_id'), volcanic_cluster=model_credential.get('volcanic_cluster'), + params=model_kwargs, + **model_kwargs, **optional_params ) def construct_request(self, reqid): + + params = self.params or {} req = { 'app': { 'appid': self.volcanic_app_id, @@ -227,24 +233,24 @@ def construct_request(self, reqid): 'token': self.volcanic_token, }, 'user': { - 'uid': 'uid' + 'uid': params.get("uid", "streaming_asr_demo") }, 'request': { 'reqid': reqid, - 'nbest': self.nbest, - 'workflow': self.workflow, - 'show_language': self.show_language, - 'show_utterances': self.show_utterances, - 'result_type': self.result_type, - "sequence": 1 + 'nbest': params.get('nbest', self.nbest), + 'workflow': params.get('workflow', self.workflow), + 'show_language': params.get('show_language', self.show_language), + 'show_utterances': params.get('show_utterances', self.show_utterances), + 'result_type': params.get('result_type', self.result_type), + 'sequence': params.get('sequence', 1) }, 'audio': { - 'format': self.format, - 'rate': self.rate, - 'language': self.language, - 'bits': self.bits, - 'channel': self.channel, - 'codec': self.codec + 'format': params.get('format', self.format), + 'rate': params.get('rate', self.rate), + 'language': params.get('language', self.language), + 'bits': params.get('bits', self.bits), + 'channel': params.get('channel', self.channel), + 'codec': params.get('codec', self.codec) } } return req diff --git a/apps/models_provider/impl/xf_model_provider/credential/stt.py b/apps/models_provider/impl/xf_model_provider/credential/stt.py index 56d697b36e9..67da706ba22 100644 --- a/apps/models_provider/impl/xf_model_provider/credential/stt.py +++ b/apps/models_provider/impl/xf_model_provider/credential/stt.py @@ -6,10 +6,28 @@ from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode +class XunFeiSTTModelParams(BaseForm): + language = forms.TextInputField( + TooltipLabel(_('language'), _('If not passed, the default value is zh_cn')), + required=True, + default_value='zh_cn' + ) + domain = forms.TextInputField( + TooltipLabel(_('domain'), _('If not passed, the default value is iat')), + required=True, + default_value='iat' + ) + accent = forms.TextInputField( + TooltipLabel(_('accent'), _('If not passed, the default value is mandarin')), + required=True, + default_value='mandarin' + ) + + class XunFeiSTTModelCredential(BaseForm, BaseModelCredential): spark_api_url = forms.TextInputField('API URL', required=True, default_value='wss://iat-api.xfyun.cn/v2/iat') spark_app_id = forms.TextInputField('APP ID', required=True) @@ -30,7 +48,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: traceback.print_exc() @@ -48,4 +66,4 @@ def encryption_dict(self, model: Dict[str, object]): return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))} def get_model_params_setting_form(self, model_name): - pass + return XunFeiSTTModelParams() diff --git a/apps/models_provider/impl/xf_model_provider/model/stt.py b/apps/models_provider/impl/xf_model_provider/model/stt.py index 09f011f577c..b43320746e3 100644 --- a/apps/models_provider/impl/xf_model_provider/model/stt.py +++ b/apps/models_provider/impl/xf_model_provider/model/stt.py @@ -34,6 +34,7 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText): spark_api_key: str spark_api_secret: str spark_api_url: str + params: dict def __init__(self, **kwargs): super().__init__(**kwargs) @@ -41,6 +42,7 @@ def __init__(self, **kwargs): self.spark_app_id = kwargs.get('spark_app_id') self.spark_api_key = kwargs.get('spark_api_key') self.spark_api_secret = kwargs.get('spark_api_secret') + self.params = kwargs.get('params') @staticmethod def is_cache_model(): @@ -58,6 +60,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** spark_api_key=model_credential.get('spark_api_key'), spark_api_secret=model_credential.get('spark_api_secret'), spark_api_url=model_credential.get('spark_api_url'), + params=model_kwargs, **optional_params ) @@ -132,6 +135,11 @@ async def send(self, ws, file): frameSize = 8000 # 每一帧的音频大小 status = STATUS_FIRST_FRAME # 音频的状态信息,标识音频是第一帧,还是中间帧、最后一帧 + allowed_params = {'language','domain','accent','vad_eos','dwa','pd','ptt', + 'pcm','ltc','rlang','vinfo','nunum','speex_size','nbest','wbest'} + + business_params = {k: v for k,v in self.params.items() if k in allowed_params} + while True: buf = file.read(frameSize) # 文件结束 @@ -144,17 +152,14 @@ async def send(self, ws, file): d = { "common": {"app_id": self.spark_app_id}, "business": { - "domain": "iat", - "language": "zh_cn", - "accent": "mandarin", - "vinfo": 1, - "vad_eos": 10000 + **business_params }, "data": { "status": 0, "format": "audio/L16;rate=16000", "audio": str(base64.b64encode(buf), 'utf-8'), "encoding": "lame"} } + print(d) d = json.dumps(d) await ws.send(d) status = STATUS_CONTINUE_FRAME diff --git a/apps/models_provider/impl/xinference_model_provider/model/stt.py b/apps/models_provider/impl/xinference_model_provider/model/stt.py index 1994cd8fd33..e614d42d4d7 100644 --- a/apps/models_provider/impl/xinference_model_provider/model/stt.py +++ b/apps/models_provider/impl/xinference_model_provider/model/stt.py @@ -22,6 +22,8 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.api_key = kwargs.get('api_key') self.api_base = kwargs.get('api_base') + self.model = kwargs.get('model') + self.params = kwargs.get('params') @staticmethod def is_cache_model(): @@ -57,5 +59,14 @@ def speech_to_text(self, audio_file): audio_data = audio_file.read() buffer = io.BytesIO(audio_data) buffer.name = "file.mp3" # this is the important line - res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer) + + filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}} + transcription_params = { + 'model': self.model, + 'file': buffer, + 'language': 'zh', + **filter_params + } + + res = client.audio.transcriptions.create(**transcription_params) return res.text