Skip to content

Commit 3d36051

Browse files
committed
feat: Application text to speech and speech to text functions
1 parent 2c0a8af commit 3d36051

File tree

16 files changed

+506
-79
lines changed

16 files changed

+506
-79
lines changed

apps/application/api/application_api.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from rest_framework import serializers
1313

1414
from application.serializers.application import ApplicationCreateSerializer, ApplicationListResponse, \
15-
ApplicationImportRequest, ApplicationEditSerializer
15+
ApplicationImportRequest, ApplicationEditSerializer, TextToSpeechRequest, SpeechToTextRequest, PlayDemoTextRequest
1616
from common.mixins.api_mixin import APIMixin
1717
from common.result import ResultSerializer, ResultPageSerializer, DefaultResultSerializer
1818

@@ -167,3 +167,45 @@ class ApplicationEditAPI(APIMixin):
167167
@staticmethod
168168
def get_request():
169169
return ApplicationEditSerializer
170+
171+
172+
class TextToSpeechAPI(APIMixin):
173+
@staticmethod
174+
def get_parameters():
175+
return ApplicationOperateAPI.get_parameters()
176+
177+
@staticmethod
178+
def get_request():
179+
return TextToSpeechRequest
180+
181+
@staticmethod
182+
def get_response():
183+
return DefaultResultSerializer
184+
185+
186+
class SpeechToTextAPI(APIMixin):
187+
@staticmethod
188+
def get_parameters():
189+
return ApplicationOperateAPI.get_parameters()
190+
191+
@staticmethod
192+
def get_request():
193+
return SpeechToTextRequest
194+
195+
@staticmethod
196+
def get_response():
197+
return DefaultResultSerializer
198+
199+
200+
class PlayDemoTextAPI(APIMixin):
201+
@staticmethod
202+
def get_parameters():
203+
return ApplicationOperateAPI.get_parameters()
204+
205+
@staticmethod
206+
def get_request():
207+
return PlayDemoTextRequest
208+
209+
@staticmethod
210+
def get_response():
211+
return DefaultResultSerializer

apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def list_paragraph(embedding_list: List, vector):
103103
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
104104
get_file_content(
105105
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
106-
'list_dataset_paragraph_by_paragraph_id.sql')),
106+
'list_knowledge_paragraph_by_paragraph_id.sql')),
107107
with_table_name=True)
108108
# 如果向量库中存在脏数据 直接删除
109109
if len(paragraph_list) != len(paragraph_id_list):

apps/application/serializers/application.py

Lines changed: 87 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
@date:2025/5/26 17:03
77
@desc:
88
"""
9+
import asyncio
910
import datetime
1011
import hashlib
12+
import json
1113
import os
1214
import pickle
1315
import re
@@ -19,6 +21,7 @@
1921
from django.db.models import QuerySet, Q
2022
from django.http import HttpResponse
2123
from django.utils.translation import gettext_lazy as _
24+
from langchain_mcp_adapters.client import MultiServerMCPClient
2225
from rest_framework import serializers, status
2326
from rest_framework.utils.formatting import lazy_format
2427

@@ -36,6 +39,7 @@
3639
from knowledge.serializers.knowledge import KnowledgeSerializer, KnowledgeModelSerializer
3740
from maxkb.conf import PROJECT_DIR
3841
from models_provider.models import Model
42+
from models_provider.tools import get_model_instance_by_model_workspace_id
3943
from system_manage.models import WorkspaceUserResourcePermission
4044
from tools.models import Tool, ToolScope
4145
from tools.serializers.tool import ToolModelSerializer
@@ -384,9 +388,9 @@ class ApplicationEditSerializer(serializers.Serializer):
384388
label=_("Historical chat records"))
385389
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400,
386390
label=_("Opening remarks"))
387-
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
388-
label=_("Related Knowledge Base")
389-
)
391+
knowledge_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
392+
label=_("Related Knowledge Base")
393+
)
390394
# 数据集相关设置
391395
knowledge_setting = KnowledgeSettingSerializer(required=False, allow_null=True,
392396
label=_("Dataset settings"))
@@ -441,8 +445,8 @@ def insert_workflow(self, instance: Dict):
441445
return ApplicationCreateSerializer.ApplicationResponse(application_model).data
442446

443447
@staticmethod
444-
def to_application_knowledge_mapping(application_id: str, dataset_id: str):
445-
return ApplicationKnowledgeMapping(id=uuid.uuid7(), application_id=application_id, dataset_id=dataset_id)
448+
def to_application_knowledge_mapping(application_id: str, knowledge_id: str):
449+
return ApplicationKnowledgeMapping(id=uuid.uuid7(), application_id=application_id, knowledge_id=knowledge_id)
446450

447451
def insert_simple(self, instance: Dict):
448452
self.is_valid(raise_exception=True)
@@ -451,10 +455,10 @@ def insert_simple(self, instance: Dict):
451455
ApplicationCreateSerializer.SimplateRequest(data=instance).is_valid(user_id=user_id, raise_exception=True)
452456
application_model = ApplicationCreateSerializer.SimplateRequest.to_application_model(user_id, workspace_id,
453457
instance)
454-
dataset_id_list = instance.get('knowledge_id_list', [])
458+
knowledge_id_list = instance.get('knowledge_id_list', [])
455459
application_knowledge_mapping_model_list = [
456-
self.to_application_knowledge_mapping(application_model.id, dataset_id) for
457-
dataset_id in dataset_id_list]
460+
self.to_application_knowledge_mapping(application_model.id, knowledge_id) for
461+
knowledge_id in knowledge_id_list]
458462
# 插入应用
459463
application_model.save()
460464
# 插入认证信息
@@ -519,15 +523,15 @@ def to_tool(tool, workspace_id, user_id):
519523
def to_application(application, workspace_id, user_id):
520524
work_flow = application.get('work_flow')
521525
for node in work_flow.get('nodes', []):
522-
if node.get('type') == 'search-dataset-node':
523-
node.get('properties', {}).get('node_data', {})['dataset_id_list'] = []
526+
if node.get('type') == 'search-knowledge-node':
527+
node.get('properties', {}).get('node_data', {})['knowledge_id_list'] = []
524528
return Application(id=uuid.uuid7(),
525529
user_id=user_id,
526530
name=application.get('name'),
527531
workspace_id=workspace_id,
528532
desc=application.get('desc'),
529533
prologue=application.get('prologue'), dialogue_number=application.get('dialogue_number'),
530-
dataset_setting=application.get('dataset_setting'),
534+
knowledge_setting=application.get('knowledge_setting'),
531535
model_setting=application.get('model_setting'),
532536
model_params_setting=application.get('model_params_setting'),
533537
tts_model_params_setting=application.get('tts_model_params_setting'),
@@ -545,6 +549,27 @@ def to_application(application, workspace_id, user_id):
545549
)
546550

547551

552+
class TextToSpeechRequest(serializers.Serializer):
553+
text = serializers.CharField(required=True, label=_('Text'))
554+
555+
556+
class SpeechToTextRequest(serializers.Serializer):
557+
file = UploadedFileField(required=True, label=_("file"))
558+
559+
560+
class PlayDemoTextRequest(serializers.Serializer):
561+
tts_model_id = serializers.UUIDField(required=True, label=_('Text to speech model ID'))
562+
563+
564+
async def get_mcp_tools(servers):
565+
async with MultiServerMCPClient(servers) as client:
566+
return client.get_tools()
567+
568+
569+
class McpServersSerializer(serializers.Serializer):
570+
mcp_servers = serializers.JSONField(required=True)
571+
572+
548573
class ApplicationOperateSerializer(serializers.Serializer):
549574
application_id = serializers.UUIDField(required=True, label=_("Application ID"))
550575
user_id = serializers.UUIDField(required=True, label=_("User ID"))
@@ -559,6 +584,23 @@ def is_valid(self, *, raise_exception=False):
559584
if not query_set.exists():
560585
raise AppApiException(500, _('Application id does not exist'))
561586

587+
def get_mcp_servers(self, instance, with_valid=True):
588+
if with_valid:
589+
self.is_valid(raise_exception=True)
590+
McpServersSerializer(data=instance).is_valid(raise_exception=True)
591+
servers = json.loads(instance.get('mcp_servers'))
592+
tools = []
593+
for server in servers:
594+
tools += [
595+
{
596+
'server': server,
597+
'name': tool.name,
598+
'description': tool.description,
599+
'args_schema': tool.args_schema,
600+
}
601+
for tool in asyncio.run(get_mcp_tools({server: servers[server]}))]
602+
return tools
603+
562604
def delete(self, with_valid=True):
563605
if with_valid:
564606
self.is_valid()
@@ -691,7 +733,7 @@ def edit(self, instance: Dict, with_valid=True):
691733
if application.type == ApplicationTypeChoices.SIMPLE.value:
692734
application.is_publish = True
693735
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
694-
'dataset_setting', 'model_setting', 'problem_optimization', 'dialogue_number',
736+
'knowledge_setting', 'model_setting', 'problem_optimization', 'dialogue_number',
695737
'stt_model_id', 'tts_model_id', 'tts_model_enable', 'stt_model_enable', 'tts_type',
696738
'tts_autoplay', 'stt_autosend', 'file_upload_enable', 'file_upload_setting',
697739
'api_key_is_active', 'icon', 'work_flow', 'model_params_setting', 'tts_model_params_setting',
@@ -746,7 +788,7 @@ def update_knowledge_node(self, workflow, available_knowledge_dict):
746788
"""
747789
修改知识库检索节点 数据
748790
定义 all_knowledge_id_list: 所有的关联知识库
749-
dataset_id_list: 当前用户可看到的关联知识库列表
791+
knowledge_id_list: 当前用户可看到的关联知识库列表
750792
knowledge_list: 用户
751793
@param workflow: 知识库
752794
@param available_knowledge_dict: 当前用户可用的知识库
@@ -802,3 +844,35 @@ def save_application_knowledge_mapping(application_knowledge_id_list, knowledge_
802844
QuerySet(ApplicationKnowledgeMapping).bulk_create(
803845
[ApplicationKnowledgeMapping(application_id=application_id, knowledge_id=knowledge_id) for knowledge_id in
804846
knowledge_id_list]) if len(knowledge_id_list) > 0 else None
847+
848+
def speech_to_text(self, instance, with_valid=True):
849+
if with_valid:
850+
self.is_valid(raise_exception=True)
851+
SpeechToTextRequest(data=instance).is_valid(raise_exception=True)
852+
application_id = self.data.get('application_id')
853+
application = QuerySet(Application).filter(id=application_id).first()
854+
if application.stt_model_enable:
855+
model = get_model_instance_by_model_workspace_id(application.stt_model_id, application.workspace_id)
856+
text = model.speech_to_text(instance.get('file'))
857+
return text
858+
859+
def text_to_speech(self, instance, with_valid=True):
860+
if with_valid:
861+
self.is_valid(raise_exception=True)
862+
TextToSpeechRequest(data=instance).is_valid(raise_exception=True)
863+
application_id = self.data.get('application_id')
864+
application = QuerySet(Application).filter(id=application_id).first()
865+
if application.tts_model_enable:
866+
model = get_model_instance_by_model_workspace_id(application.tts_model_id, application.workspace_id,
867+
**application.tts_model_params_setting)
868+
869+
return model.text_to_speech(instance.get('text'))
870+
871+
def play_demo_text(self, instance, with_valid=True):
872+
text = '你好,这里是语音播放测试'
873+
if with_valid:
874+
self.is_valid(raise_exception=True)
875+
PlayDemoTextRequest(data=instance).is_valid(raise_exception=True)
876+
tts_model_id = instance.pop('tts_model_id')
877+
model = get_model_instance_by_model_workspace_id(tts_model_id, self.data.get('workspace_id'), **instance)
878+
return model.text_to_speech(text)

apps/application/urls.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030
path('workspace/<str:workspace_id>/application/<str:application_id>/work_flow_version/<int:current_page>/<int:page_size>', views.ApplicationVersionView.Page.as_view()),
3131
path('workspace/<str:workspace_id>/application/<str:application_id>/work_flow_version/<str:work_flow_version_id>', views.ApplicationVersionView.Operate.as_view()),
3232
path('workspace/<str:workspace_id>/application/<str:application_id>/open', views.OpenView.as_view()),
33+
path('workspace/<str:workspace_id>/application/<str:application_id>/text_to_speech', views.TextToSpeech.as_view()),
34+
path('workspace/<str:workspace_id>/application/<str:application_id>/speech_to_text', views.SpeechToText.as_view()),
35+
path('workspace/<str:workspace_id>/application/<str:application_id>/play_demo_text', views.PlayDemoText.as_view()),
36+
path('workspace/<str:workspace_id>/application/<str:application_id>/mcp_tools', views.McpServers.as_view()),
3337
path('chat_message/<str:chat_id>', views.ChatView.as_view()),
3438

3539
]

apps/application/views/application.py

Lines changed: 103 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,22 @@
77
@desc:
88
"""
99
from django.db.models import QuerySet
10+
from django.http import HttpResponse
1011
from django.utils.translation import gettext_lazy as _
1112
from drf_spectacular.utils import extend_schema
1213
from rest_framework.parsers import MultiPartParser
1314
from rest_framework.request import Request
1415
from rest_framework.views import APIView
1516

1617
from application.api.application_api import ApplicationCreateAPI, ApplicationQueryAPI, ApplicationImportAPI, \
17-
ApplicationExportAPI, ApplicationOperateAPI, ApplicationEditAPI
18+
ApplicationExportAPI, ApplicationOperateAPI, ApplicationEditAPI, TextToSpeechAPI, SpeechToTextAPI, PlayDemoTextAPI
1819
from application.models import Application
19-
from application.serializers.application import ApplicationSerializer, Query, ApplicationOperateSerializer
20+
from application.serializers.application import ApplicationSerializer, Query, ApplicationOperateSerializer, \
21+
McpServersSerializer
2022
from common import result
2123
from common.auth import TokenAuth
2224
from common.auth.authentication import has_permissions
23-
from common.constants.permission_constants import PermissionConstants, RoleConstants
25+
from common.constants.permission_constants import PermissionConstants, RoleConstants, CompareConstants
2426
from common.log.log import log
2527

2628

@@ -233,3 +235,101 @@ def put(self, request: Request, workspace_id: str, application_id: str):
233235
ApplicationOperateSerializer(
234236
data={'application_id': application_id, 'user_id': request.user.id,
235237
'workspace_id': workspace_id, }).publish(request.data))
238+
239+
240+
class McpServers(APIView):
241+
authentication_classes = [TokenAuth]
242+
243+
@extend_schema(
244+
methods=['GET'],
245+
description=_("speech to text"),
246+
summary=_("speech to text"),
247+
operation_id=_("speech to text"), # type: ignore
248+
parameters=SpeechToTextAPI.get_parameters(),
249+
request=SpeechToTextAPI.get_request(),
250+
responses=SpeechToTextAPI.get_response(),
251+
tags=[_('Application')] # type: ignore
252+
)
253+
@has_permissions(PermissionConstants.APPLICATION_READ.get_workspace_application_permission(),
254+
PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(),
255+
RoleConstants.USER.get_workspace_role(),
256+
RoleConstants.WORKSPACE_MANAGE.get_workspace_role())
257+
def get(self, request: Request, workspace_id, application_id: str):
258+
return result.success(ApplicationOperateSerializer(
259+
data={'mcp_servers': request.query_params.get('mcp_servers')}).get_mcp_servers())
260+
261+
262+
class SpeechToText(APIView):
263+
authentication_classes = [TokenAuth]
264+
265+
@extend_schema(
266+
methods=['POST'],
267+
description=_("speech to text"),
268+
summary=_("speech to text"),
269+
operation_id=_("speech to text"), # type: ignore
270+
parameters=SpeechToTextAPI.get_parameters(),
271+
request=SpeechToTextAPI.get_request(),
272+
responses=SpeechToTextAPI.get_response(),
273+
tags=[_('Application')] # type: ignore
274+
)
275+
@has_permissions(PermissionConstants.APPLICATION_EDIT.get_workspace_application_permission(),
276+
PermissionConstants.APPLICATION_EDIT.get_workspace_permission_workspace_manage_role(),
277+
RoleConstants.USER.get_workspace_role(),
278+
RoleConstants.WORKSPACE_MANAGE.get_workspace_role())
279+
def post(self, request: Request, workspace_id: str, application_id: str):
280+
return result.success(
281+
ApplicationOperateSerializer(
282+
data={'application_id': application_id, 'workspace_id': workspace_id, 'user_id': request.user.id})
283+
.speech_to_text({'file': request.FILES.get('file')}))
284+
285+
286+
class TextToSpeech(APIView):
287+
authentication_classes = [TokenAuth]
288+
289+
@extend_schema(
290+
methods=['POST'],
291+
description=_("text to speech"),
292+
summary=_("text to speech"),
293+
operation_id=_("text to speech"), # type: ignore
294+
parameters=TextToSpeechAPI.get_parameters(),
295+
request=TextToSpeechAPI.get_request(),
296+
responses=TextToSpeechAPI.get_response(),
297+
tags=[_('Application')] # type: ignore
298+
)
299+
@has_permissions(PermissionConstants.APPLICATION_EDIT.get_workspace_application_permission(),
300+
PermissionConstants.APPLICATION_EDIT.get_workspace_permission_workspace_manage_role(),
301+
RoleConstants.USER.get_workspace_role(),
302+
RoleConstants.WORKSPACE_MANAGE.get_workspace_role())
303+
def post(self, request: Request, workspace_id: str, application_id: str):
304+
byte_data = ApplicationOperateSerializer(
305+
data={'application_id': application_id, 'workspace_id': workspace_id,
306+
'user_id': request.user.id}).text_to_speech(request.data)
307+
return HttpResponse(byte_data, status=200, headers={'Content-Type': 'audio/mp3',
308+
'Content-Disposition': 'attachment; filename="abc.mp3"'})
309+
310+
311+
class PlayDemoText(APIView):
312+
authentication_classes = [TokenAuth]
313+
314+
@extend_schema(
315+
methods=['POST'],
316+
description=_("PlayDemo"),
317+
summary=_("PlayDemo"),
318+
operation_id=_("PlayDemo"), # type: ignore
319+
parameters=PlayDemoTextAPI.get_parameters(),
320+
request=PlayDemoTextAPI.get_request(),
321+
responses=PlayDemoTextAPI.get_response(),
322+
tags=[_('Application')] # type: ignore
323+
)
324+
@has_permissions(PermissionConstants.APPLICATION_EDIT.get_workspace_application_permission(),
325+
PermissionConstants.APPLICATION_EDIT.get_workspace_permission_workspace_manage_role(),
326+
RoleConstants.USER.get_workspace_role(),
327+
RoleConstants.WORKSPACE_MANAGE.get_workspace_role())
328+
@log(menu='Application', operate="trial listening",
329+
get_operation_object=lambda r, k: get_application_operation_object(k.get('application_id')))
330+
def post(self, request: Request, workspace_id: str, application_id: str):
331+
byte_data = ApplicationOperateSerializer(
332+
data={'application_id': application_id, 'workspace_id': workspace_id,
333+
'user_id': request.user.id}).play_demo_text(request.data)
334+
return HttpResponse(byte_data, status=200, headers={'Content-Type': 'audio/mp3',
335+
'Content-Disposition': 'attachment; filename="abc.mp3"'})

0 commit comments

Comments
 (0)