Skip to content

Commit 92aec4d

Browse files
committed
feat: add speech_to_text node and text_to_speech node
1 parent 7bd791f commit 92aec4d

File tree

36 files changed

+1063
-42
lines changed

36 files changed

+1063
-42
lines changed

apps/application/flow/step_node/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@
2121
from .image_generate_step_node import *
2222

2323
from .search_dataset_node import *
24+
from .speech_to_text_step_node import BaseSpeechToTextNode
2425
from .start_node import *
26+
from .text_to_speech_step_node.impl.base_text_to_speech_node import BaseTextToSpeechNode
2527

2628
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
2729
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode,
2830
BaseDocumentExtractNode,
29-
BaseImageUnderstandNode, BaseImageGenerateNode, BaseFormNode]
31+
BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode,BaseImageGenerateNode]
3032

3133

3234
def get_node(node_type):

apps/application/flow/step_node/application_node/i_application_node.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class ApplicationNodeSerializer(serializers.Serializer):
1414
user_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.uuid("用户输入字段"))
1515
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
1616
document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档"))
17+
audio_list = serializers.ListField(required=False, error_messages=ErrMessage.list("音频"))
1718
child_node = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("子节点"))
1819
node_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("表单数据"))
1920

@@ -43,19 +44,30 @@ def _run(self):
4344
app_document_list[1:])
4445
for document in app_document_list:
4546
if 'file_id' not in document:
46-
raise ValueError("参数值错误: 上传的文档中缺少file_id")
47+
raise ValueError("参数值错误: 上传的文档中缺少file_id,文档上传失败")
4748
app_image_list = self.node_params_serializer.data.get('image_list', [])
4849
if app_image_list and len(app_image_list) > 0:
4950
app_image_list = self.workflow_manage.get_reference_field(
5051
app_image_list[0],
5152
app_image_list[1:])
5253
for image in app_image_list:
5354
if 'file_id' not in image:
54-
raise ValueError("参数值错误: 上传的图片中缺少file_id")
55+
raise ValueError("参数值错误: 上传的图片中缺少file_id,图片上传失败")
56+
57+
app_audio_list = self.node_params_serializer.data.get('audio_list', [])
58+
if app_audio_list and len(app_audio_list) > 0:
59+
app_audio_list = self.workflow_manage.get_reference_field(
60+
app_audio_list[0],
61+
app_audio_list[1:])
62+
for audio in app_audio_list:
63+
if 'file_id' not in audio:
64+
raise ValueError("参数值错误: 上传的图片中缺少file_id,音频上传失败")
5565
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data,
5666
app_document_list=app_document_list, app_image_list=app_image_list,
67+
app_audio_list=app_audio_list,
5768
message=str(question), **kwargs)
5869

5970
def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
60-
app_document_list=None, app_image_list=None, child_node=None, node_data=None, **kwargs) -> NodeResult:
71+
app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None,
72+
**kwargs) -> NodeResult:
6173
pass

apps/application/flow/step_node/application_node/impl/base_application_node.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def save_context(self, details, workflow_manage):
154154
self.answer_text = details.get('answer')
155155

156156
def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
157-
app_document_list=None, app_image_list=None, child_node=None, node_data=None,
157+
app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None,
158158
**kwargs) -> NodeResult:
159159
from application.serializers.chat_message_serializers import ChatMessageSerializer
160160
# 生成嵌入应用的chat_id
@@ -167,6 +167,8 @@ def execute(self, application_id, message, chat_id, chat_record_id, stream, re_c
167167
app_document_list = []
168168
if app_image_list is None:
169169
app_image_list = []
170+
if app_audio_list is None:
171+
app_audio_list = []
170172
runtime_node_id = None
171173
record_id = None
172174
child_node_value = None
@@ -186,6 +188,7 @@ def execute(self, application_id, message, chat_id, chat_record_id, stream, re_c
186188
'client_type': client_type,
187189
'document_list': app_document_list,
188190
'image_list': app_image_list,
191+
'audio_list': app_audio_list,
189192
'runtime_node_id': runtime_node_id,
190193
'chat_record_id': record_id,
191194
'child_node': child_node_value,
@@ -234,5 +237,6 @@ def get_details(self, index: int, **kwargs):
234237
'global_fields': global_fields,
235238
'document_list': self.workflow_manage.document_list,
236239
'image_list': self.workflow_manage.image_list,
240+
'audio_list': self.workflow_manage.audio_list,
237241
'application_node_dict': self.context.get('application_node_dict')
238242
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# coding=utf-8
2+
3+
from .impl import *
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# coding=utf-8
2+
3+
from typing import Type
4+
5+
from rest_framework import serializers
6+
7+
from application.flow.i_step_node import INode, NodeResult
8+
from common.util.field_message import ErrMessage
9+
10+
11+
class SpeechToTextNodeSerializer(serializers.Serializer):
    """Validate the parameters configured on a speech-to-text workflow node."""

    # Identifier of the speech-to-text model; the only mandatory parameter.
    stt_model_id = serializers.CharField(
        required=True,
        error_messages=ErrMessage.char("模型id"),
    )

    # Optional boolean; its error label reads "whether to return content".
    is_result = serializers.BooleanField(
        required=False,
        error_messages=ErrMessage.boolean('是否返回内容'),
    )

    # Optional list; ISpeechToTextNode._run hands its first element and the
    # remaining elements to workflow_manage.get_reference_field.
    audio_list = serializers.ListField(
        required=False,
        error_messages=ErrMessage.list("音频"),
    )
17+
18+
19+
class ISpeechToTextNode(INode):
    """Interface of the speech-to-text workflow node.

    Resolves the audio reference configured on the node, validates the
    resolved entries and hands them to ``execute``, which concrete
    subclasses implement.
    """
    type = 'speech-to-text-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate this node's parameters."""
        return SpeechToTextNodeSerializer

    def _run(self):
        """Resolve the referenced audio list, validate it and call ``execute``.

        Raises:
            ValueError: if no audio reference is configured on the node, or
                if a resolved audio entry has no ``file_id``.
        """
        audio_reference = self.node_params_serializer.data.get('audio_list')
        # ``audio_list`` is declared with required=False, so it can be absent
        # or empty; fail with an explicit message instead of an opaque
        # TypeError/IndexError from subscripting it.
        if not audio_reference:
            raise ValueError("参数值错误: 未选择音频文件")

        res = self.workflow_manage.get_reference_field(audio_reference[0], audio_reference[1:])
        for audio in res:
            if 'file_id' not in audio:
                raise ValueError("参数值错误: 上传的音频中缺少file_id,音频上传失败")

        return self.execute(audio=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)

    def execute(self, stt_model_id, chat_id,
                audio,
                **kwargs) -> NodeResult:
        """Transcribe ``audio``; implemented by concrete node classes.

        Args:
            stt_model_id: identifier of the speech-to-text model to use.
            chat_id: supplied from the flow parameters.
            audio: the resolved audio entries, each carrying a ``file_id``.
            **kwargs: remaining node and flow parameters.
        """
        pass
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# coding=utf-8
2+
3+
from .base_speech_to_text_node import BaseSpeechToTextNode
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# coding=utf-8
2+
import os
3+
import tempfile
4+
import time
5+
import io
6+
from typing import List, Dict
7+
8+
from django.db.models import QuerySet
9+
from pydub import AudioSegment
10+
from concurrent.futures import ThreadPoolExecutor
11+
from application.flow.i_step_node import NodeResult, INode
12+
from application.flow.step_node.speech_to_text_step_node.i_speech_to_text_node import ISpeechToTextNode
13+
from common.util.common import split_and_transcribe
14+
from dataset.models import File
15+
from setting.models_provider.tools import get_model_instance_by_model_user_id
16+
17+
18+
class BaseSpeechToTextNode(ISpeechToTextNode):
    """Speech-to-text node that transcribes stored audio files with an STT model."""

    def save_context(self, details, workflow_manage):
        """Restore the node's answer from previously recorded ``details``."""
        self.context['answer'] = details.get('answer')
        self.answer_text = details.get('answer')

    def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult:
        """Transcribe every entry of ``audio`` and join the transcripts.

        Args:
            stt_model_id: identifier of the speech-to-text model.
            chat_id: supplied from the flow parameters; not used here.
            audio: iterable of mappings, each with a ``file_id`` key.
            **kwargs: remaining node and flow parameters; not used here.

        Returns:
            A ``NodeResult`` whose ``answer`` and ``result`` hold the
            transcripts joined by blank lines, in the order of ``audio``.

        Raises:
            ValueError: if a referenced file record does not exist.
        """
        stt_model = get_model_instance_by_model_user_id(stt_model_id, self.flow_params_serializer.data.get('user_id'))
        self.context['audio_list'] = audio

        def process_audio_item(audio_item, model):
            """Write one stored audio file to a temp file and transcribe it."""
            file = QuerySet(File).filter(id=audio_item['file_id']).first()
            if file is None:
                # Fail explicitly rather than with AttributeError on None.
                raise ValueError(f"参数值错误: 音频文件不存在, file_id: {audio_item['file_id']}")

            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
            temp_file_path = temp_file.name
            try:
                # The file is closed before transcription starts, and the
                # try/finally covers the write as well, so a failed write
                # does not leave the temp file behind.
                with temp_file:
                    temp_file.write(file.get_byte().tobytes())
                return split_and_transcribe(temp_file_path, model)
            finally:
                os.remove(temp_file_path)

        def process_audio_items(audio_items, model):
            """Transcribe all items on a small thread pool, preserving order."""
            with ThreadPoolExecutor(max_workers=5) as executor:
                results = list(executor.map(lambda item: process_audio_item(item, model), audio_items))
            return '\n\n'.join(results)

        result = process_audio_items(audio, stt_model)
        return NodeResult({'answer': result, 'result': result}, {})

    def get_details(self, index: int, **kwargs):
        """Return the execution details recorded for this node."""
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'answer': self.context.get('answer'),
            'type': self.node.type,
            'status': self.status,
            'err_message': self.err_message,
            'audio_list': self.context.get('audio_list'),
        }

apps/application/flow/step_node/start_node/impl/base_start_node.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def save_context(self, details, workflow_manage):
3939
self.context['run_time'] = details.get('run_time')
4040
self.context['document'] = details.get('document_list')
4141
self.context['image'] = details.get('image_list')
42+
self.context['audio'] = details.get('audio_list')
4243
self.status = details.get('status')
4344
self.err_message = details.get('err_message')
4445
for key, value in workflow_variable.items():
@@ -57,7 +58,8 @@ def execute(self, question, **kwargs) -> NodeResult:
5758
node_variable = {
5859
'question': question,
5960
'image': self.workflow_manage.image_list,
60-
'document': self.workflow_manage.document_list
61+
'document': self.workflow_manage.document_list,
62+
'audio': self.workflow_manage.audio_list
6163
}
6264
return NodeResult(node_variable, workflow_variable)
6365

@@ -80,5 +82,6 @@ def get_details(self, index: int, **kwargs):
8082
'err_message': self.err_message,
8183
'image_list': self.context.get('image'),
8284
'document_list': self.context.get('document'),
85+
'audio_list': self.context.get('audio'),
8386
'global_fields': global_fields
8487
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# coding=utf-8
2+
3+
from .impl import *
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# coding=utf-8
2+
3+
from typing import Type
4+
5+
from rest_framework import serializers
6+
7+
from application.flow.i_step_node import INode, NodeResult
8+
from common.util.field_message import ErrMessage
9+
10+
11+
class TextToSpeechNodeSerializer(serializers.Serializer):
    """Validate the parameters configured on a text-to-speech workflow node."""

    # Identifier of the text-to-speech model; the only mandatory parameter.
    tts_model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))

    # Optional boolean; its error label reads "whether to return content".
    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))

    # Optional list; ITextToSpeechNode._run hands its first element and the
    # remaining elements to workflow_manage.get_reference_field.
    content_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文本内容"))

    # A DictField needs the dict error messages, not the integer ones the
    # original used.
    model_params_setting = serializers.DictField(required=False,
                                                 error_messages=ErrMessage.dict("模型参数相关设置"))
19+
20+
21+
class ITextToSpeechNode(INode):
    """Interface of the text-to-speech workflow node.

    Resolves the content reference configured on the node and hands the
    resolved value to ``execute``, which concrete subclasses implement.
    """
    type = 'text-to-speech-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate this node's parameters."""
        return TextToSpeechNodeSerializer

    def _run(self):
        """Resolve the referenced content and call ``execute``.

        Raises:
            ValueError: if no content reference is configured on the node.
        """
        content_reference = self.node_params_serializer.data.get('content_list')
        # ``content_list`` is declared with required=False, so it can be
        # absent or empty; fail with an explicit message instead of an
        # opaque TypeError/IndexError from subscripting it.
        if not content_reference:
            raise ValueError("参数值错误: 未选择文本内容")

        content = self.workflow_manage.get_reference_field(content_reference[0], content_reference[1:])
        return self.execute(content=content, **self.node_params_serializer.data, **self.flow_params_serializer.data)

    def execute(self, tts_model_id, chat_id,
                content, model_params_setting=None,
                **kwargs) -> NodeResult:
        """Synthesize speech for ``content``; implemented by concrete node classes.

        Args:
            tts_model_id: identifier of the text-to-speech model to use.
            chat_id: supplied from the flow parameters.
            content: the value resolved from the configured content reference.
            model_params_setting: optional model parameter settings.
            **kwargs: remaining node and flow parameters.
        """
        pass

0 commit comments

Comments
 (0)