Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class SpeechToTextNodeSerializer(serializers.Serializer):

audio_list = serializers.ListField(required=True,
label=_("The audio file cannot be empty"))
model_params_setting = serializers.DictField(required=False,
label=_("Model parameter settings"))


class ISpeechToTextNode(INode):
Expand All @@ -35,6 +37,6 @@ def _run(self):
return self.execute(audio=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)

def execute(self, stt_model_id, chat_id,
audio,
audio, model_params_setting=None,
**kwargs) -> NodeResult:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def save_context(self, details, workflow_manage):
if self.node_params.get('is_result', False):
self.answer_text = details.get('answer')

def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult:
def execute(self, stt_model_id, chat_id, audio, model_params_setting=None, **kwargs) -> NodeResult:
workspace_id = self.workflow_manage.get_body().get('workspace_id')
stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id)
stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id, **model_params_setting)
audio_list = audio
self.context['audio_list'] = audio

Expand Down
2 changes: 1 addition & 1 deletion apps/application/serializers/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ def speech_to_text(self, instance, debug=True, with_valid=True):
application = QuerySet(ApplicationVersion).filter(application_id=application_id).order_by(
'-create_time').first()
if application.stt_model_enable:
model = get_model_instance_by_model_workspace_id(application.stt_model_id, application.workspace_id)
model = get_model_instance_by_model_workspace_id(application.stt_model_id, application.workspace_id, **application.stt_model_params_setting)
text = model.speech_to_text(instance.get('file'))
return text

Expand Down
9 changes: 9 additions & 0 deletions apps/locales/en_US/LC_MESSAGES/django.po
Original file line number Diff line number Diff line change
Expand Up @@ -8718,4 +8718,13 @@ msgid "Failed to obtain the image"
msgstr ""

msgid "Update auth setting"
msgstr ""

msgid "If not passed, the default value is streaming_asr_demo"
msgstr ""

msgid "If not passed, the default value is 16000"
msgstr ""

msgid "Sample Rate"
msgstr ""
11 changes: 10 additions & 1 deletion apps/locales/zh_CN/LC_MESSAGES/django.po
Original file line number Diff line number Diff line change
Expand Up @@ -8844,4 +8844,13 @@ msgid "Failed to obtain the image"
msgstr "获取图片失败"

msgid "Update auth setting"
msgstr "更新认证设置"
msgstr "更新认证设置"

msgid "If not passed, the default value is streaming_asr_demo"
msgstr "如果未传入,则默认值为 streaming_asr_demo"

msgid "If not passed, the default value is 16000"
msgstr "如果未传入,则默认值为 16000"

msgid "Sample Rate"
msgstr "采样率"
11 changes: 10 additions & 1 deletion apps/locales/zh_Hant/LC_MESSAGES/django.po
Original file line number Diff line number Diff line change
Expand Up @@ -8844,4 +8844,13 @@ msgid "Failed to obtain the image"
msgstr "獲取圖片失敗"

msgid "Update auth setting"
msgstr "更新認證設置"
msgstr "更新認證設置"

msgid "If not passed, the default value is streaming_asr_demo"
msgstr "如果未傳入,則預設值為 streaming_asr_demo"

msgid "If not passed, the default value is 16000"
msgstr "如果未傳入,則預設值為 16000"

msgid "Sample Rate"
msgstr "採樣率"
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,21 @@
from typing import Dict, Any

from django.utils.translation import gettext as _

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField
from common.forms import BaseForm, PasswordInputField, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode


class AliyunBaiLianSTTModelParams(BaseForm):
    """
    Model parameter form for the Aliyun BaiLian STT (Speech-to-Text) model.

    Declares a single required slider field, ``sample_rate``, whose default
    value is 16000.
    """
    sample_rate = forms.SliderField(
        TooltipLabel(
            _('Sample Rate'),
            _('If not passed, the default value is 16000'),
        ),
        required=True,
        default_value=16000,
        # NOTE(review): presumably the slider's step and bounds. A minimum of 0
        # is unlikely to be a usable sample rate -- confirm against the
        # provider's supported values.
        _step=4000,
        _min=0,
        _max=20000,
        precision=0,
    )

class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
"""
Credential class for the Aliyun BaiLian STT (Speech-to-Text) model.
Expand Down Expand Up @@ -55,7 +65,7 @@ def is_valid(
return False

try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential,**model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
Expand Down Expand Up @@ -89,4 +99,4 @@ def get_model_params_setting_form(self, model_name: str):
:param model_name: Name of the model.
:return: Parameter setting form (not implemented).
"""
pass
return AliyunBaiLianSTTModelParams()
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText):
api_key: str
model: str
params: dict

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.model = kwargs.get('model')
self.params = kwargs.get('params')


@staticmethod
def is_cache_model():
Expand All @@ -33,6 +36,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
return AliyunBaiLianSpeechToText(
model=model_name,
api_key=model_credential.get('api_key'),
params=model_kwargs,
**optional_params,
)

Expand All @@ -43,10 +47,17 @@ def check_auth(self):

def speech_to_text(self, audio_file):
dashscope.api_key = self.api_key
recognition = Recognition(model=self.model,
format='mp3',
sample_rate=16000,
callback=None)
recognition_params = {
'model': self.model,
'format': 'mp3',
'sample_rate': 16000,
'callback': None,
**self.params
}
print(recognition_params)
recognition = Recognition(**recognition_params)


with tempfile.NamedTemporaryFile(delete=False) as temp_file:
# 将上传的文件保存到临时文件中
temp_file.write(audio_file.read())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There appears to be an issue in the speech_to_text method. The line:

dashscope.api_key = self.api_key

Should likely be changed to:

dashscope.API_KEY = self.api_key

Assuming that dashscope is a global variable or class attribute used throughout your script, this error would prevent it from using the appropriate API key for authentication during processing.

Additionally, the parameters dictionary is printed with print(recognition_params). This is unnecessary outside of debugging; if the print can be removed without affecting functionality, it should be removed for cleaner code.

Here's the adjusted part relevant to fixing the auth issue (and optionally removing unnecessary print statements):

def speech_to_text(self, audio_file):
    recognition_params = {
        'model': self.model,
        'format': 'mp3',
        'sample_rate': 16000,
        'callback': None,
        **self.params
    }
    
    dashscope.API_KEY = self.api_key
    
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        # Ensure correct context for file handle operations
        with open(temp_file.name, "wb"):
            temp_file.write(audio_file.read())

# Optionally remove or comment out the following
# print("Recognition Parameters:", recognition_params)

This fix ensures proper authorization handling while maintaining clarity within your application code.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
Expand Down
13 changes: 12 additions & 1 deletion apps/models_provider/impl/azure_model_provider/model/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ class AzureOpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
api_key: str
api_version: str
model: str
params: dict

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base')
self.api_version = kwargs.get('api_version')
self.params = kwargs.get('params')

@staticmethod
def is_cache_model():
Expand All @@ -41,6 +43,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
api_base=model_credential.get('api_base'),
api_key=model_credential.get('api_key'),
api_version=model_credential.get('api_version'),
params=model_kwargs,
**optional_params,
)

Expand All @@ -62,5 +65,13 @@ def speech_to_text(self, audio_file):
audio_data = audio_file.read()
buffer = io.BytesIO(audio_data)
buffer.name = "file.mp3" # this is the important line
res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)

filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}}
transcription_params = {
'model': self.model,
'file': buffer,
'language': 'zh'
}

res = client.audio.transcriptions.create(**transcription_params, extra_body=filter_params)
return res.text
14 changes: 11 additions & 3 deletions apps/models_provider/impl/openai_model_provider/credential/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,17 @@

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode


class OpenAISTTModelParams(BaseForm):
    """
    Model parameter form for the OpenAI speech-to-text model.

    Declares a single required text field, ``language``, whose default value
    is ``'zh'``.
    """
    language = forms.TextInputField(
        TooltipLabel(
            _('language'),
            _('If not passed, the default value is zh'),
        ),
        required=True,
        default_value='zh',
    )

class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API URL', required=True)
api_key = forms.PasswordInputField('API Key', required=True)
Expand All @@ -28,7 +35,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
Expand All @@ -46,4 +53,5 @@ def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

def get_model_params_setting_form(self, model_name):
pass

return OpenAISTTModelParams()
14 changes: 13 additions & 1 deletion apps/models_provider/impl/openai_model_provider/model/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
api_base: str
api_key: str
model: str
params: dict

@staticmethod
def is_cache_model():
Expand All @@ -27,6 +28,8 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base')
self.model = kwargs.get('model')
self.params = kwargs.get('params')

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
Expand All @@ -39,6 +42,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
model=model_name,
api_base=model_credential.get('api_base'),
api_key=model_credential.get('api_key'),
params = model_kwargs,
**optional_params,
)

Expand All @@ -58,6 +62,14 @@ def speech_to_text(self, audio_file):
audio_data = audio_file.read()
buffer = io.BytesIO(audio_data)
buffer.name = "file.mp3" # this is the important line
res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)

filter_params = {k: v for k,v in self.params.items() if k not in {'model_id','use_local','streaming'}}
transcription_params = {
'model': self.model,
'file': buffer,
'language': 'zh'
}

res = client.audio.transcriptions.create(**transcription_params,extra_body=filter_params)
return res.text

Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential,**model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ class SiliconCloudSpeechToText(MaxKBBaseModel, BaseSpeechToText):
api_base: str
api_key: str
model: str
params: dict

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base')
self.params = kwargs.get('params')

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
Expand All @@ -35,6 +37,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
model=model_name,
api_base=model_credential.get('api_base'),
api_key=model_credential.get('api_key'),
params=model_kwargs,
**optional_params,
)

Expand All @@ -58,5 +61,13 @@ def speech_to_text(self, audio_file):
audio_data = audio_file.read()
buffer = io.BytesIO(audio_data)
buffer.name = "file.mp3" # this is the important line
res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)

filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}}
transcription_params = {
'model': self.model,
'file': buffer,
'language': 'zh'
}

res = client.audio.transcriptions.create(**transcription_params,extra_body=filter_params)
return res.text
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,14 @@ def speech_to_text(self, audio_file):
base_url=base_url
)

filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}}
transcription_params = {
'model': self.model,
'file': audio_file,
'language': 'zh',
}
result = client.audio.transcriptions.create(
file=audio_file,
model=self.model,
language=self.params.get('Language'),
response_format="json"
**transcription_params, extra_body=filter_params
)

return result.text
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,17 @@

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode

class VolcanicEngineSTTModelParams(BaseForm):
    """
    Model parameter form for the Volcanic Engine speech-to-text model.

    Declares a single required text field, ``uid``, whose default value is
    ``'streaming_asr_demo'``.
    """
    uid = forms.TextInputField(
        TooltipLabel(
            _('User ID'),
            _('If not passed, the default value is streaming_asr_demo'),
        ),
        required=True,
        default_value='streaming_asr_demo',
    )



class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
volcanic_api_url = forms.TextInputField('API URL', required=True,
Expand All @@ -31,7 +39,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
Expand All @@ -49,4 +57,4 @@ def encryption_dict(self, model: Dict[str, object]):
return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))}

def get_model_params_setting_form(self, model_name):
pass
return VolcanicEngineSTTModelParams()
Loading
Loading