-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
380 lines (338 loc) · 16 KB
/
main.py
File metadata and controls
380 lines (338 loc) · 16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
from ncnk_message import (
MessageServer,
Router,
RouteConfig,
TargetConfig,
MessageBase,
Seg,
FormatInfo,
)
import asyncio
from typing import List, Tuple, Dict, Optional
import importlib
import toml
import random
from pathlib import Path
from tts_src.config import Config
from tts_src.logger import logger
from tts_src.plugins.base_tts_model import BaseTTSModel
from tts_src.utils.audio_encode import encode_audio, encode_audio_stream
from tts_src.utils import post_process
class TTSPipeline:
    """Bridge between a ``MessageServer`` and a ``Router`` that voices ``tts_text`` content.

    Messages received by the server (``server_handle``) are forwarded to the router.
    Messages delivered by the router (``client_handle``) are either passed back to the
    server unchanged or converted to voice with the first enabled TTS plugin. The voice
    is produced either as a stream of ``voice_stream`` segments or as a single ``voice``
    segment, depending on ``config.tts_base_config.stream_mode``.
    """

    # Enabled TTS plugin instances, filled by ``import_module``. Assigned per instance in
    # ``__init__`` so that separate pipelines do not share (and append to) one list.
    tts_list: List[BaseTTSModel]

    def __init__(self, config_path: str):
        """Load configuration and wire the message server and router together.

        Args:
            config_path: Path of the configuration file handed to ``Config``.
        """
        self.config: Config = Config(config_path)
        self.tts_list = []
        # Refresh the logging level from configuration.
        from tts_src.logger import set_logging_level

        set_logging_level(self.config.config_data["debug"].get("logging_level", "INFO"))
        self.server = MessageServer(
            host=self.config.server.host,
            port=self.config.server.port,
        )
        # One route target per configured platform.
        route_config = {
            platform: TargetConfig(url=url, token=None)
            for platform, url in self.config.routes.items()
        }
        self.router = Router(RouteConfig(route_config))
        self.server.register_message_handler(self.server_handle)
        self.router.register_class_handler(self.client_handle)
        # Per-chat (group id, else user id) queue of (text, explicit language, message)
        # items, plus the worker task currently draining each queue.
        self.text_buffer_dict: Dict[str, asyncio.Queue[Tuple[str, Optional[str], MessageBase]]] = {}
        self.buffer_task_dict: Dict[str, asyncio.Task] = {}

    def import_module(self):
        """Import every enabled TTS plugin and instantiate its ``TTSModel``.

        Each name in ``config.enabled_plugin.enabled`` is imported as the module
        ``tts_src.plugins.<name>``, and the instance is appended to ``self.tts_list``.

        Raises:
            ImportError: The plugin module could not be imported.
            AttributeError: The module has no ``TTSModel`` attribute.
            Exception: Anything else raised while importing or instantiating; re-raised after logging.
        """
        for tts in self.config.enabled_plugin.enabled:
            module_name = f"tts_src.plugins.{tts}"
            try:
                module = importlib.import_module(module_name)
                tts_class: BaseTTSModel = module.TTSModel()
                self.tts_list.append(tts_class)
            except ImportError as e:
                logger.error(f"Error importing {module_name}: {e}")
                raise
            except AttributeError as e:
                logger.error(f"Error accessing TTSModel in {module_name}: {e}")
                raise
            except Exception as e:
                logger.error(f"Unexpected error importing {module_name}: {e}")
                raise

    async def start(self):
        """Import plugins, log the project version, and start the server and router.

        Returns:
            ``(server_task, router_task)`` so the caller can await or cancel them.
        """
        self.import_module()
        py_project_path = Path(__file__).parent / "pyproject.toml"
        toml_data = toml.load(py_project_path)
        logger.info(f"版本信息\n\n当前版本: {toml_data['project']['version']}\n")
        # Create tasks (rather than awaiting a gather) so the caller controls their lifetime.
        self.server_task = asyncio.create_task(self.server.run())
        self.router_task = asyncio.create_task(self.router.run())
        return self.server_task, self.router_task

    async def server_handle(self, message_data: dict):
        """Forward a message received by the server to the router.

        If the sender accepts ``voice``, ``tts_text`` is added to its accepted formats
        (at most once). Presumably this lets the router side reply with ``tts_text``,
        which ``client_handle`` then voices.
        """
        message = MessageBase.from_dict(message_data)
        format_info = message.message_info.format_info
        accept_format = format_info.accept_format if format_info else None
        if accept_format and "voice" in accept_format and "tts_text" not in accept_format:
            accept_format.append("tts_text")
        await self.router.send_message(message)

    def process_seg(self, seg: Seg) -> Tuple[str, Optional[str]]:
        """Recursively extract the text and explicit language of ``tts_text`` segments.

        ``seglist`` children are visited in order and their texts concatenated. A language
        declared by a later child overrides an earlier one. A ``tts_text`` segment's data
        is either a dict with ``text``/``lang`` keys or a plain string. Any other segment
        type contributes nothing.

        Returns:
            ``(text, lang)``, where ``lang`` is ``None`` when no segment declared one.
        """
        if seg.type == "seglist":
            parts: List[str] = []
            text_lang: Optional[str] = None
            for child in seg.data:
                child_text, child_lang = self.process_seg(child)
                parts.append(child_text)
                if child_lang:
                    text_lang = child_lang
            return "".join(parts), text_lang
        if seg.type == "tts_text":
            if isinstance(seg.data, dict):
                return seg.data.get("text") or "", seg.data.get("lang")
            return (seg.data if isinstance(seg.data, str) else ""), None
        return "", None

    @staticmethod
    def _resolve_text_lang(message: MessageBase, text_lang: Optional[str] = None) -> Optional[str]:
        """Pick the synthesis language for ``message``.

        Returns ``text_lang`` if it is truthy. Otherwise returns the ``tts_language`` or
        ``text_lang`` entry of ``additional_config``, in that order, or ``None``.
        """
        if text_lang:
            return text_lang
        extra = message.message_info.additional_config
        if extra:
            return extra.get("tts_language") or extra.get("text_lang")
        return None

    @staticmethod
    def _apply_voice_metadata(
        message: MessageBase, content_format: str, text: str, text_lang: Optional[str]
    ) -> None:
        """Mark ``message`` as carrying ``content_format`` audio and record the source text and language."""
        info = message.message_info
        if not info.format_info:
            info.format_info = FormatInfo(
                content_format=[],
                accept_format=[],
            )
        info.format_info.content_format = [content_format]
        if not info.additional_config:
            info.additional_config = {}
        info.additional_config["original_text"] = text
        if text_lang:
            info.additional_config["text_lang"] = text_lang
            info.additional_config["tts_language"] = text_lang

    def _select_tts(self) -> Optional[BaseTTSModel]:
        """Return the TTS plugin to use (currently the first enabled one), or ``None`` if none is enabled."""
        if not self.tts_list:
            logger.warning("没有启用任何tts,跳过处理")
            return None
        return self.tts_list[0]

    async def client_handle(self, message_dict: dict) -> None:
        """Handle a message from the router: pass it through or convert its text to voice.

        A message whose segment type is not ``tts_text`` is considered for TTS only with
        probability ``config.probability.voice_probability``. It is forwarded to the
        server unchanged when the roll fails, when it carries no ``tts_text`` content, or
        when no chat id can be determined. In stream mode the voice is streamed
        immediately. Otherwise the text is queued per chat (group id, falling back to
        user id) and voiced in arrival order by a per-chat worker.
        """
        message = MessageBase.from_dict(message_dict)
        if message.message_segment.type != "tts_text" and random.random() > self.config.probability.voice_probability:
            # Probability not met: forward unchanged.
            await self.server.send_message(message)
            return

        message_text, text_lang = self.process_seg(message.message_segment)
        if not message_text:
            # No tts_text content (e.g. a non-text message): forward unchanged. This is
            # checked before the stream branch so stream mode does not drop such messages.
            await self.server.send_message(message)
            return

        if self.config.tts_base_config.stream_mode:
            await self.send_voice_stream(message)
            return

        text_lang = self._resolve_text_lang(message, text_lang)

        # Chat key: prefer the group id, otherwise the user id.
        group_id = getattr(message.message_info.group_info, "group_id", None)
        if group_id is None:
            logger.warning("没有群消息id使用用户id代替")
            group_id = getattr(message.message_info.user_info, "user_id", None)
        if not group_id:
            logger.warning("无法定位目标发送位置,跳过TTS处理")
            await self.server.send_message(message)
            return
        group_id = str(group_id)

        queue = self.text_buffer_dict.get(group_id)
        if queue is None:
            queue = self.text_buffer_dict[group_id] = asyncio.Queue()
        # Enqueue first, then make sure a live worker exists. There is no await between
        # these steps, and a worker deregisters itself right after finding the queue
        # empty, so every item is picked up by either the current or a new worker.
        queue.put_nowait((message_text, text_lang, message))
        worker = self.buffer_task_dict.get(group_id)
        if worker is None or worker.done():
            self.buffer_task_dict[group_id] = asyncio.create_task(self._buffer_queue_handler(group_id))

    async def _buffer_queue_handler(self, group_id: str) -> None:
        """Drain one chat's queue, synthesizing and sending each queued text in order.

        The worker exits once the queue is empty and removes itself from
        ``buffer_task_dict``. A failure while handling one item is logged and does not
        stop the worker.
        """
        queue = self.text_buffer_dict[group_id]
        try:
            while True:
                try:
                    message_text, text_lang, message = queue.get_nowait()
                except asyncio.QueueEmpty:
                    return
                try:
                    await self._synthesize_and_send(group_id, message_text, text_lang, message)
                except Exception as exc:
                    logger.exception(f"TTS->Napcat 发送异常: {exc}")
                finally:
                    queue.task_done()
        finally:
            # Deregister only if the registry still points at this task. Deregistration
            # runs without an await after the empty check, so client_handle never sees a
            # worker that will no longer read the queue.
            if self.buffer_task_dict.get(group_id) is asyncio.current_task():
                del self.buffer_task_dict[group_id]

    async def _synthesize_and_send(
        self, group_id: str, message_text: str, text_lang: Optional[str], message: MessageBase
    ) -> None:
        """Synthesize ``message_text`` and send ``message`` with its segment replaced by the voice."""
        text = message_text.strip() if message_text else ""
        if not text or not message:
            logger.warning("数据为空,跳过处理")
            return
        logger.info(f"[聊天: {group_id}]将合成文本: {text}")
        new_seg = await self.get_voice_no_stream(text, message.message_info.platform, text_lang=text_lang)
        if not new_seg:
            logger.warning("语音消息为空,跳过发送")
            return
        message.message_segment = new_seg
        self._apply_voice_metadata(message, "voice", text, text_lang)
        logger.debug(
            f"TTS->Napcat 即将发送: platform={message.message_info.platform}, formats={message.message_info.format_info.content_format}"
        )
        ok = await self.server.send_message(message)
        logger.info(
            f"TTS->Napcat send: platform={message.message_info.platform}, ok={ok}, formats={message.message_info.format_info.content_format}"
        )
        if not ok:
            logger.warning("send_message 返回 False,检查平台映射或连接状态")

    async def cleanup_task(self, group_id: str):
        """Deregister the buffer worker of ``group_id`` and cancel it, if any.

        Unknown group ids are ignored. A worker that calls this for its own group is
        deregistered but not cancelled.
        """
        task = self.buffer_task_dict.pop(group_id, None)
        if task is not None and task is not asyncio.current_task() and not task.done():
            task.cancel()

    async def get_voice_no_stream(self, text: str, platform: str, text_lang: str | None = None) -> Seg | None:
        """Synthesize ``text`` in one shot and wrap it as a ``voice`` segment.

        If ``tts_base_config.post_process`` is enabled, the audio is passed through
        ``post_process.simulate_telephone_voice`` before being encoded by ``encode_audio``.

        Returns:
            The ``voice`` segment, or ``None`` if no plugin is enabled or synthesis
            failed (the error is logged).
        """
        tts_class = self._select_tts()
        if tts_class is None:
            return None
        try:
            audio_data = await tts_class.tts(text=text, platform=platform, text_lang=text_lang)
            if self.config.tts_base_config.post_process:
                # Optional "telephone voice" post-processing.
                audio_data = post_process.simulate_telephone_voice(audio_data)
            # Base64-encode the whole audio payload.
            encoded_audio = encode_audio(audio_data)
            logger.debug(f"生成语音数据长度: {len(encoded_audio)} (base64)")
            return Seg(type="voice", data=encoded_audio)
        except Exception as e:
            logger.error(f"TTS处理过程中发生错误: {str(e)}")
            logger.info(f"文本为: {text}")
            return None

    async def send_voice_stream(self, message: MessageBase) -> None:
        """Synthesize the message's ``tts_text`` content as a stream and send each chunk.

        Every non-empty chunk is encoded with ``encode_audio_stream`` and sent as a
        ``voice_stream`` segment on ``message``. A failure on one chunk is logged and the
        remaining chunks are still sent.
        """
        platform = message.message_info.platform
        message_text, text_lang = self.process_seg(message.message_segment)
        text_lang = self._resolve_text_lang(message, text_lang)
        if not message_text:
            logger.warning("处理文本为空,跳过发送")
            return
        text = message_text
        tts_class = self._select_tts()
        if tts_class is None:
            return None
        try:
            audio_stream = await tts_class.tts_stream(text=text, platform=platform, text_lang=text_lang)
            # Metadata is identical for every chunk; set it once (this also creates
            # format_info when it is missing instead of failing on every chunk).
            self._apply_voice_metadata(message, "voice_stream", text, text_lang)
            for chunk in audio_stream:
                if not chunk:
                    continue
                try:
                    message.message_segment = Seg(type="voice_stream", data=encode_audio_stream(chunk))
                except Exception as e:
                    logger.error(f"处理音频块时发生错误: {str(e)}")
                    continue
                try:
                    ok = await self.server.send_message(message)
                    logger.debug(f"流式分片发送: platform={message.message_info.platform}, ok={ok}")
                    if not ok:
                        logger.warning(f"流式语音发送失败 platform={message.message_info.platform}")
                except Exception as exc:
                    logger.exception(f"流式发送异常: {exc}")
            logger.info("流式语音消息发送完成")
        except Exception as e:
            logger.error(f"TTS处理过程中发生错误: {str(e)}")
            logger.info(f"文本为: {text}")
        return None

    async def stop(self):
        """Cancel buffer workers, then the server/router tasks, and stop both endpoints.

        Failures in each shutdown step are logged rather than raised, so later steps
        still run.
        """
        logger.info("正在停止TTS服务...")
        # Cancel every still-running buffer worker and wait for it to unwind.
        for task in list(self.buffer_task_dict.values()):
            if task.done():
                continue
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                logger.error(f"取消缓冲任务时出错: {e}")
        self.buffer_task_dict.clear()

        # Cancel the server/router tasks created by start(), if they exist.
        tasks_to_cancel = []
        for attr in ("server_task", "router_task"):
            task = getattr(self, attr, None)
            if task is not None and not task.done():
                task.cancel()
                tasks_to_cancel.append(task)
        if tasks_to_cancel:
            try:
                await asyncio.gather(*tasks_to_cancel, return_exceptions=True)
            except Exception as e:
                logger.error(f"等待任务取消时出错: {e}")

        # Stop the router first, since it holds the client connections.
        try:
            await self.router.stop()
        except Exception as e:
            logger.error(f"停止路由器时发生错误: {e}")
        try:
            await self.server.stop()
        except Exception as e:
            logger.error(f"停止服务器时发生错误: {e}")
        # Give connections a moment to close completely.
        await asyncio.sleep(0.1)
        logger.info("TTS服务已停止")
async def main():
    """Run the TTS pipeline until its tasks finish or fail, then shut it down.

    Shutdown waits at most 15 seconds for ``TTSPipeline.stop``. A timeout or error
    during shutdown is logged, not raised.
    """

    async def shutdown(pipe) -> None:
        # Bounded wait so a hung stop() cannot block process exit indefinitely.
        try:
            await asyncio.wait_for(pipe.stop(), timeout=15.0)
        except asyncio.TimeoutError:
            logger.warning("关闭服务超时,强制退出")
        except Exception as err:
            logger.error(f"关闭服务时发生错误: {err}")
        else:
            logger.info("服务已安全关闭")

    config_file = Path(__file__).parent / "configs" / "base.toml"
    pipeline = TTSPipeline(str(config_file))
    try:
        logger.info("正在启动TTS服务...")
        srv_task, rtr_task = await pipeline.start()
        logger.info("TTS服务已启动,按 Ctrl+C 退出")
        # Block until both the server task and the router task complete (or one fails).
        await asyncio.gather(srv_task, rtr_task)
    except KeyboardInterrupt:
        logger.debug("\n接收到键盘中断信号...")
    except Exception as err:
        logger.error(f"运行过程中发生错误: {err}")
    finally:
        logger.info("正在关闭服务...")
        await shutdown(pipeline)
        # Short pause so remaining resources can be released.
        await asyncio.sleep(0.2)
if __name__ == "__main__":
    import time

    try:
        # asyncio.run() creates the event loop, runs main(), and closes the loop afterwards.
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("\n程序已退出")
    except Exception as exc:
        logger.error(f"程序启动失败: {exc}")
    finally:
        # Give the system a moment to finish any remaining cleanup.
        time.sleep(0.1)