Skip to content

Commit 1608897

Browse files
committed
refactor: enhance text-to-speech processing by splitting content into chunks and merging audio segments
1 parent 06e759a commit 1608897

File tree

2 files changed

+56
-12
lines changed

2 files changed

+56
-12
lines changed

apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66

77
from application.flow.i_step_node import NodeResult
88
from application.flow.step_node.text_to_speech_step_node.i_text_to_speech_node import ITextToSpeechNode
9+
from common.utils.common import _remove_empty_lines
910
from knowledge.models import FileSourceType
1011
from models_provider.tools import get_model_instance_by_model_workspace_id
1112
from oss.serializers.file import FileSerializer
13+
from pydub import AudioSegment
1214

1315

1416
def bytes_to_uploaded_file(file_bytes, file_name="generated_audio.mp3"):
@@ -41,32 +43,72 @@ def save_context(self, details, workflow_manage):
4143

4244
def execute(self, tts_model_id, chat_id,
4345
content, model_params_setting=None,
44-
**kwargs) -> NodeResult:
45-
self.context['content'] = content
46-
workspace_id = self.workflow_manage.get_body().get('workspace_id')
47-
model = get_model_instance_by_model_workspace_id(tts_model_id, workspace_id,
48-
**model_params_setting)
49-
audio_byte = model.text_to_speech(content)
50-
# 需要把这个音频文件存储到数据库中
51-
file_name = 'generated_audio.mp3'
52-
file = bytes_to_uploaded_file(audio_byte, file_name)
46+
max_length=1024, **kwargs) -> NodeResult:
47+
# 分割文本为合理片段
48+
content = _remove_empty_lines(content)
49+
content_chunks = [content[i:i + max_length]
50+
for i in range(0, len(content), max_length)]
51+
52+
# 生成并收集所有音频片段
53+
audio_segments = []
54+
temp_files = []
55+
56+
for i, chunk in enumerate(content_chunks):
57+
self.context['content'] = chunk
58+
workspace_id = self.workflow_manage.get_body().get('workspace_id')
59+
model = get_model_instance_by_model_workspace_id(
60+
tts_model_id, workspace_id, **model_params_setting)
61+
62+
audio_byte = model.text_to_speech(chunk)
63+
64+
# 保存为临时音频文件用于合并
65+
temp_file = io.BytesIO(audio_byte)
66+
audio_segment = AudioSegment.from_file(temp_file)
67+
audio_segments.append(audio_segment)
68+
temp_files.append(temp_file)
69+
70+
# 合并所有音频片段
71+
combined_audio = AudioSegment.empty()
72+
for segment in audio_segments:
73+
combined_audio += segment
74+
75+
# 将合并后的音频转为字节流
76+
output_buffer = io.BytesIO()
77+
combined_audio.export(output_buffer, format="mp3")
78+
combined_bytes = output_buffer.getvalue()
79+
80+
# 存储合并后的音频文件
81+
file_name = 'combined_audio.mp3'
82+
file = bytes_to_uploaded_file(combined_bytes, file_name)
83+
5384
application = self.workflow_manage.work_flow_post_handler.chat_info.application
5485
meta = {
5586
'debug': False if application.id else True,
5687
'chat_id': chat_id,
5788
'application_id': str(application.id) if application.id else None,
5889
}
90+
5991
file_url = FileSerializer(data={
6092
'file': file,
6193
'meta': meta,
6294
'source_id': meta['application_id'],
6395
'source_type': FileSourceType.APPLICATION.value
6496
}).upload()
65-
# 拼接一个audio标签的src属性
66-
audio_label = f'<audio src="{file_url}" controls style = "width: 300px; height: 43px"></audio>'
97+
98+
# 生成音频标签
99+
audio_label = f'<audio src="{file_url}" controls style="width: 300px; height: 43px"></audio>'
67100
file_id = file_url.split('/')[-1]
68101
audio_list = [{'file_id': file_id, 'file_name': file_name, 'url': file_url}]
69-
return NodeResult({'answer': audio_label, 'result': audio_list}, {})
102+
103+
# 关闭所有临时文件
104+
for temp_file in temp_files:
105+
temp_file.close()
106+
output_buffer.close()
107+
108+
return NodeResult({
109+
'answer': audio_label,
110+
'result': audio_list
111+
}, {})
70112

71113
def get_details(self, index: int, **kwargs):
72114
return {

apps/models_provider/impl/volcanic_engine_model_provider/model/tts.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import json
1515
import re
1616
import ssl
17+
18+
import requests
1719
import uuid_utils.compat as uuid
1820
from typing import Dict
1921

0 commit comments

Comments
 (0)