From a3d469d49c9c7faff027e25be4a8ebf402c48da4 Mon Sep 17 00:00:00 2001 From: AliveGh0st <1724728802@qq.com> Date: Sat, 26 Apr 2025 17:50:08 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E2=9C=A8feat:=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E5=AF=B9Gemini=E6=A8=A1=E5=9E=8B=E7=9A=84=E9=9F=B3=E9=A2=91?= =?UTF-8?q?=E5=A4=84=E7=90=86=E6=94=AF=E6=8C=81=EF=BC=8C=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=20ProviderRequest=20=E4=BB=A5=E5=8C=85=E5=90=AB=E9=9F=B3?= =?UTF-8?q?=E9=A2=91=20URL=20=E5=88=97=E8=A1=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../process_stage/method/llm_request.py | 21 ++- astrbot/core/provider/entities.py | 2 + .../core/provider/sources/gemini_source.py | 128 ++++++++++++++---- 3 files changed, 122 insertions(+), 29 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index fd70275d8..e9280c3b5 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -14,7 +14,7 @@ ResultContentType, MessageChain, ) -from astrbot.core.message.components import Image +from astrbot.core.message.components import Image, Record from astrbot.core import logger from astrbot.core.utils.metrics import Metric from astrbot.core.provider.entities import ( @@ -77,16 +77,33 @@ async def process( ) else: - req = ProviderRequest(prompt="", image_urls=[]) + req = ProviderRequest(prompt="", image_urls=[], audio_urls=[]) if self.provider_wake_prefix: if not event.message_str.startswith(self.provider_wake_prefix): return req.prompt = event.message_str[len(self.provider_wake_prefix) :] req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() + + # 处理消息中的图片和音频 + has_audio = False for comp in event.message_obj.message: if isinstance(comp, Image): image_path = await comp.convert_to_file_path() req.image_urls.append(image_path) + elif isinstance(comp, Record): + # 处理音频消息 + audio_path = await comp.convert_to_file_path() + logger.info(f"检测到音频消息,路径: {audio_path}") + has_audio = True + if hasattr(req, "audio_urls"): + req.audio_urls.append(audio_path) + else: + # 为了兼容性,如果ProviderRequest没有audio_urls属性 + req.audio_urls = [audio_path] + + # 如果只有音频没有文本,添加默认文本 + if not req.prompt and has_audio: + req.prompt = "[用户发送的音频将其视为文本输入与其进行聊天]" # 获取对话上下文 conversation_id = await self.conv_manager.get_curr_conversation_id( diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 6ad67da55..6bd28d01e 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -96,6 +96,8 @@ class ProviderRequest: """会话 ID""" image_urls: List[str] = None """图片 URL 列表""" + audio_urls: List[str] = None + """音频 URL 列表""" func_tool: FuncCall = None """可用的函数工具""" contexts: List = None diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index a175a3d68..00fa2f61a 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -3,6 +3,8 @@ import json import logging import random +import os +import mimetypes from typing import Dict, List, Optional from collections.abc import AsyncGenerator @@ -193,6 +195,12 @@ def process_image_url(image_url_dict: dict) -> types.Part: mime_type = url.split(":")[1].split(";")[0] image_bytes = base64.b64decode(url.split(",", 1)[1]) return types.Part.from_bytes(data=image_bytes, mime_type=mime_type) + + def process_inline_data(inline_data_dict: 
dict) -> types.Part: + """处理内联数据,如音频""" # TODO: 处理视频? + mime_type = inline_data_dict["mime_type"] + data = inline_data_dict.get("data", "") + return types.Part.from_bytes(data=data, mime_type=mime_type) def append_or_extend(contents: list[types.Content], part: list[types.Part], content_cls: type[types.Content]) -> None: if contents and isinstance(contents[-1], content_cls): @@ -212,12 +220,15 @@ def append_or_extend(contents: list[types.Content], part: list[types.Part], cont if role == "user": if isinstance(content, list): - parts = [ - types.Part.from_text(text=item["text"] or " ") - if item["type"] == "text" - else process_image_url(item["image_url"]) - for item in content - ] + parts = [] + for item in content: + if item["type"] == "text": + parts.append(types.Part.from_text(text=item["text"] or " ")) + elif item["type"] == "image_url": + parts.append(process_image_url(item["image_url"])) + elif item["type"] == "inline_data": + # 处理内联数据,如音频 + parts.append(process_inline_data(item["inline_data"])) else: parts = [create_text_part(content)] append_or_extend(gemini_contents, parts, types.UserContent) @@ -447,13 +458,14 @@ async def text_chat( prompt: str, session_id: str = None, image_urls: List[str] = None, + audio_urls: List[str] = None, func_tool: FuncCall = None, contexts=[], system_prompt=None, tool_calls_result=None, **kwargs, ) -> LLMResponse: - new_record = await self.assemble_context(prompt, image_urls) + new_record = await self.assemble_context(prompt, image_urls, audio_urls) context_query = [*contexts, new_record] if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) @@ -486,14 +498,15 @@ async def text_chat_stream( self, prompt: str, session_id: str = None, - image_urls: List[str] = [], + image_urls: List[str] = None, + audio_urls: List[str] = None, func_tool: FuncCall = None, contexts=[], system_prompt=None, tool_calls_result=None, **kwargs, ) -> AsyncGenerator[LLMResponse, None]: - new_record = await self.assemble_context(prompt, image_urls) + new_record = await self.assemble_context(prompt, image_urls, audio_urls) context_query = [*contexts, new_record] if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) @@ -545,30 +558,55 @@ def set_key(self, key): self.chosen_api_key = key self._init_client() - async def assemble_context(self, text: str, image_urls: List[str] = None): + async def assemble_context(self, text: str, image_urls: List[str] = None, audio_urls: List[str] = None): """ 组装上下文。 """ - if image_urls: + has_media = (image_urls and len(image_urls) > 0) or (audio_urls and len(audio_urls) > 0) + + if has_media: user_content = { "role": "user", - "content": [{"type": "text", "text": text if text else "[图片]"}], + "content": [{"type": "text", "text": text if text else "[媒体内容]"}], } - for image_url in image_urls: - if image_url.startswith("http"): - image_path = await download_image_by_url(image_url) - image_data = await self.encode_image_bs64(image_path) - elif image_url.startswith("file:///"): - image_path = image_url.replace("file:///", "") - image_data = await self.encode_image_bs64(image_path) - else: - image_data = await self.encode_image_bs64(image_url) - if not image_data: - logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") - continue - user_content["content"].append( - {"type": "image_url", "image_url": {"url": image_data}} - ) + + # 处理图片 + if image_urls: + for image_url in image_urls: + if image_url.startswith("http"): + image_path = await download_image_by_url(image_url) + image_data = await 
self.encode_image_bs64(image_path) + elif image_url.startswith("file:///"): + image_path = image_url.replace("file:///", "") + image_data = await self.encode_image_bs64(image_path) + else: + image_data = await self.encode_image_bs64(image_url) + if not image_data: + logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") + continue + user_content["content"].append( + {"type": "image_url", "image_url": {"url": image_data}} + ) + + # 处理音频 + if audio_urls: + for audio_url in audio_urls: + audio_bytes, mime_type = await self.encode_audio_data(audio_url) + if not audio_bytes or not mime_type: + logger.warning(f"音频 {audio_url} 处理失败,将忽略。") + continue + + # 添加音频数据 + user_content["content"].append( + { + "type": "inline_data", + "inline_data": { + "mime_type": mime_type, + "data": audio_bytes + } + } + ) + return user_content else: return {"role": "user", "content": text} @@ -584,5 +622,41 @@ async def encode_image_bs64(self, image_url: str) -> str: return "data:image/jpeg;base64," + image_bs64 return "" + async def encode_audio_data(self, audio_url: str) -> tuple: + """ + 读取音频文件并返回二进制数据 + + Returns: + tuple: (音频二进制数据, MIME类型) + """ + try: + # 直接读取文件二进制数据 + with open(audio_url, "rb") as f: + audio_bytes = f.read() + + # 推断 MIME 类型 + mime_type = mimetypes.guess_type(audio_url)[0] + if not mime_type: + # 根据文件扩展名确定 MIME 类型 + extension = os.path.splitext(audio_url)[1].lower() + if extension == '.wav': + mime_type = 'audio/wav' + elif extension == '.mp3': + mime_type = 'audio/mpeg' + elif extension == '.ogg': + mime_type = 'audio/ogg' + elif extension == '.flac': + mime_type = 'audio/flac' + elif extension == '.m4a': + mime_type = 'audio/mp4' + else: + mime_type = 'audio/wav' # 默认 + + logger.info(f"音频文件处理成功: {audio_url},mime类型: {mime_type},大小: {len(audio_bytes)} 字节") + return audio_bytes, mime_type + except Exception as e: + logger.error(f"音频文件处理失败: {e}") + return None, None + async def terminate(self): logger.info("Google GenAI 适配器已终止。") From da0eb2a46366f472264128052f68c046845fe4a9 Mon Sep 17 00:00:00 2001 From: AliveGh0st <1724728802@qq.com> Date: Sun, 27 Apr 2025 08:55:56 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=F0=9F=90=9Bfix:=20=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E5=AF=B9req=E5=AF=B9=E8=B1=A1audio=5Furls=E5=B1=9E=E6=80=A7?= =?UTF-8?q?=E6=97=A0=E7=94=A8=E7=9A=84=E5=88=A4=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/pipeline/process_stage/method/llm_request.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index e9280c3b5..19bac39ad 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -95,11 +95,7 @@ async def process( audio_path = await comp.convert_to_file_path() logger.info(f"检测到音频消息,路径: {audio_path}") has_audio = True - if hasattr(req, "audio_urls"): - req.audio_urls.append(audio_path) - else: - # 为了兼容性,如果ProviderRequest没有audio_urls属性 - req.audio_urls = [audio_path] + req.audio_urls.append(audio_path) # 如果只有音频没有文本,添加默认文本 if not req.prompt and has_audio: From 7ac151a6393b2e4c1f586321447d0337e2b82448 Mon Sep 17 00:00:00 2001 From: AliveGh0st <1724728802@qq.com> Date: Tue, 29 Apr 2025 16:31:34 +0800 Subject: [PATCH 3/4] =?UTF-8?q?fix:=20=E5=B0=86=E9=9F=B3=E9=A2=91=E4=B8=8A?= =?UTF-8?q?=E4=B8=8B=E6=96=87=E6=AD=A3=E7=A1=AE=E7=BB=84=E8=A3=85=E4=B8=BA?= =?UTF-8?q?OpenAI=E6=A0=BC=E5=BC=8Fbase64=E6=95=B0=E6=8D=AE?= 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/provider/sources/gemini_source.py | 104 +++++++++--------- 1 file changed, 50 insertions(+), 54 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 00fa2f61a..611fd4214 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -196,11 +196,27 @@ def process_image_url(image_url_dict: dict) -> types.Part: image_bytes = base64.b64decode(url.split(",", 1)[1]) return types.Part.from_bytes(data=image_bytes, mime_type=mime_type) - def process_inline_data(inline_data_dict: dict) -> types.Part: - """处理内联数据,如音频""" # TODO: 处理视频? - mime_type = inline_data_dict["mime_type"] - data = inline_data_dict.get("data", "") - return types.Part.from_bytes(data=data, mime_type=mime_type) + def process_input_audio(input_audio_dict: dict) -> types.Part: + """处理音频数据""" + audio_base64 = input_audio_dict.get("data", "") + audio_format = input_audio_dict.get("format", "") + + # 将 base64 字符串解码为二进制数据 + audio_bytes = base64.b64decode(audio_base64) + + # 根据音频格式确定 MIME 类型 + mime_type_map = { + "wav": "audio/wav", + "mp3": "audio/mp3", + "aiff": "audio/aiff", + "aac": "audio/aac", + "ogg": "audio/ogg", + "flac": "audio/flac", + } + mime_type = mime_type_map.get(audio_format, "audio/wav") + + logger.debug(f"处理 OpenAI 格式音频数据,格式: {audio_format}, MIME类型: {mime_type}") + return types.Part.from_bytes(data=audio_bytes, mime_type=mime_type) def append_or_extend(contents: list[types.Content], part: list[types.Part], content_cls: type[types.Content]) -> None: if contents and isinstance(contents[-1], content_cls): @@ -226,9 +242,8 @@ def append_or_extend(contents: list[types.Content], part: list[types.Part], cont parts.append(types.Part.from_text(text=item["text"] or " ")) elif item["type"] == "image_url": parts.append(process_image_url(item["image_url"])) - elif item["type"] == "inline_data": - # 处理内联数据,如音频 - parts.append(process_inline_data(item["inline_data"])) + elif item["type"] == "input_audio": + parts.append(process_input_audio(item["input_audio"])) else: parts = [create_text_part(content)] append_or_extend(gemini_contents, parts, types.UserContent) @@ -560,7 +575,7 @@ def set_key(self, key): async def assemble_context(self, text: str, image_urls: List[str] = None, audio_urls: List[str] = None): """ - 组装上下文。 + 组装上下文。将用户输入(文本、图片和音频)组装成 OpenAI 格式的上下文数据。 """ has_media = (image_urls and len(image_urls) > 0) or (audio_urls and len(audio_urls) > 0) @@ -586,26 +601,22 @@ async def assemble_context(self, text: str, image_urls: List[str] = None, audio_ continue user_content["content"].append( {"type": "image_url", "image_url": {"url": image_data}} - ) - + ) # 处理音频 if audio_urls: for audio_url in audio_urls: - audio_bytes, mime_type = await self.encode_audio_data(audio_url) - if not audio_bytes or not mime_type: - logger.warning(f"音频 {audio_url} 处理失败,将忽略。") - continue - - # 添加音频数据 - user_content["content"].append( - { - "type": "inline_data", - "inline_data": { - "mime_type": mime_type, - "data": audio_bytes - } - } - ) + try: + audio_base64, audio_format = await self.encode_audio_bs64(audio_url) + if audio_base64 and audio_format: + user_content["content"].append({ + "type": "input_audio", + "input_audio": { + "data": audio_base64, + "format": audio_format + } + }) + except Exception as e: + logger.error(f"音频文件处理失败: {audio_url}, 错误: {e}") return user_content else: @@ -622,40 +633,25 @@ async def 
encode_image_bs64(self, image_url: str) -> str: return "data:image/jpeg;base64," + image_bs64 return "" - async def encode_audio_data(self, audio_url: str) -> tuple: + async def encode_audio_bs64(self, audio_url: str) -> tuple: """ - 读取音频文件并返回二进制数据 - - Returns: - tuple: (音频二进制数据, MIME类型) + 将音频文件转换为 base64 编码 """ try: - # 直接读取文件二进制数据 + # 读取音频文件并编码为 base64 with open(audio_url, "rb") as f: audio_bytes = f.read() - - # 推断 MIME 类型 - mime_type = mimetypes.guess_type(audio_url)[0] - if not mime_type: - # 根据文件扩展名确定 MIME 类型 - extension = os.path.splitext(audio_url)[1].lower() - if extension == '.wav': - mime_type = 'audio/wav' - elif extension == '.mp3': - mime_type = 'audio/mpeg' - elif extension == '.ogg': - mime_type = 'audio/ogg' - elif extension == '.flac': - mime_type = 'audio/flac' - elif extension == '.m4a': - mime_type = 'audio/mp4' - else: - mime_type = 'audio/wav' # 默认 - - logger.info(f"音频文件处理成功: {audio_url},mime类型: {mime_type},大小: {len(audio_bytes)} 字节") - return audio_bytes, mime_type + audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") + + # 确定音频格式 + extension = os.path.splitext(audio_url)[1].lower() + # 移除扩展名前面的点号 + audio_format = extension[1:] if extension.startswith('.') else extension + + logger.info(f"音频文件转换成功: {audio_url},格式: {audio_format},大小: {len(audio_bytes)} 字节") + return audio_base64, audio_format except Exception as e: - logger.error(f"音频文件处理失败: {e}") + logger.error(f"音频文件转换失败: {audio_url}, 错误: {e}") return None, None async def terminate(self): From 83bc1c5a192e9b8aefe5e7477c17786fb53dc1be Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 1 May 2025 23:15:02 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=F0=9F=8E=88=20perf:=20=E5=AE=8C=E5=96=84?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=E5=AE=9A=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../process_stage/method/llm_request.py | 4 +- astrbot/core/provider/provider.py | 8 +- .../core/provider/sources/anthropic_source.py | 2 + .../core/provider/sources/dashscope_source.py | 3 + astrbot/core/provider/sources/dify_source.py | 3 + .../core/provider/sources/gemini_source.py | 89 +++++++++++-------- .../core/provider/sources/openai_source.py | 8 ++ astrbot/core/provider/sources/zhipu_source.py | 15 ++-- 8 files changed, 87 insertions(+), 45 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 19bac39ad..fdab7c765 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -83,7 +83,7 @@ async def process( return req.prompt = event.message_str[len(self.provider_wake_prefix) :] req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() - + # 处理消息中的图片和音频 has_audio = False for comp in event.message_obj.message: @@ -96,7 +96,7 @@ async def process( logger.info(f"检测到音频消息,路径: {audio_path}") has_audio = True req.audio_urls.append(audio_path) - + # 如果只有音频没有文本,添加默认文本 if not req.prompt and has_audio: req.prompt = "[用户发送的音频将其视为文本输入与其进行聊天]" diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 96547c5c2..e9b3d1cd1 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -91,6 +91,8 @@ async def text_chat( contexts: List = None, system_prompt: str = None, tool_calls_result: ToolCallsResult = None, + audio_urls: List[str] = None, + video_urls: List[str] = None, **kwargs, ) -> LLMResponse: """获得 LLM 的文本对话结果。会使用当前的模型进行对话。 @@ 
-98,10 +100,12 @@ async def text_chat( Args: prompt: 提示词 session_id: 会话 ID(此属性已经被废弃) - image_urls: 图片 URL 列表 + image_urls: 图片 URL 列表,需要模型支持。 tools: Function-calling 工具 contexts: 上下文 tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling + audio_urls: 传给模型的音频 URL 列表,需要模型支持。 + video_urls: 传给模型的视频 URL 列表,需要模型支持。 kwargs: 其他参数 Notes: @@ -119,6 +123,8 @@ async def text_chat_stream( contexts: List = None, system_prompt: str = None, tool_calls_result: ToolCallsResult = None, + audio_urls: List[str] = None, + video_urls: List[str] = None, **kwargs, ) -> AsyncGenerator[LLMResponse, None]: """获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。 diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 319515c52..4c43c473f 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -107,6 +107,8 @@ async def text_chat( contexts=[], system_prompt=None, tool_calls_result: ToolCallsResult = None, + audio_urls: List[str] = None, + video_urls: List[str] = None, **kwargs, ) -> LLMResponse: if not prompt: diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index 2c4930692..81b052588 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -72,6 +72,9 @@ async def text_chat( func_tool: FuncCall = None, contexts: List = None, system_prompt: str = None, + tool_calls_result=None, + audio_urls=None, + video_urls=None, **kwargs, ) -> LLMResponse: # 获得会话变量 diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index 78e3760c1..5234ffc85 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -64,6 +64,9 @@ async def text_chat( func_tool: FuncCall = None, contexts: List = None, system_prompt: str = None, + tool_calls_result=None, + audio_urls=None, + video_urls=None, **kwargs, ) -> LLMResponse: result = "" diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 611fd4214..eb6ccb989 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -4,7 +4,6 @@ import logging import random import os -import mimetypes from typing import Dict, List, Optional from collections.abc import AsyncGenerator @@ -164,13 +163,17 @@ async def _prepare_query_config( return types.GenerateContentConfig( system_instruction=system_instruction, temperature=temperature, - max_output_tokens=payloads.get("max_tokens") or payloads.get("maxOutputTokens"), + max_output_tokens=payloads.get("max_tokens") + or payloads.get("maxOutputTokens"), top_p=payloads.get("top_p") or payloads.get("topP"), top_k=payloads.get("top_k") or payloads.get("topK"), - frequency_penalty=payloads.get("frequency_penalty") or payloads.get("frequencyPenalty"), - presence_penalty=payloads.get("presence_penalty") or payloads.get("presencePenalty"), + frequency_penalty=payloads.get("frequency_penalty") + or payloads.get("frequencyPenalty"), + presence_penalty=payloads.get("presence_penalty") + or payloads.get("presencePenalty"), stop_sequences=payloads.get("stop") or payloads.get("stopSequences"), - response_logprobs=payloads.get("response_logprobs") or payloads.get("responseLogprobs"), + response_logprobs=payloads.get("response_logprobs") + or payloads.get("responseLogprobs"), 
logprobs=payloads.get("logprobs"), seed=payloads.get("seed"), response_modalities=modalities, @@ -195,15 +198,15 @@ def process_image_url(image_url_dict: dict) -> types.Part: mime_type = url.split(":")[1].split(";")[0] image_bytes = base64.b64decode(url.split(",", 1)[1]) return types.Part.from_bytes(data=image_bytes, mime_type=mime_type) - + def process_input_audio(input_audio_dict: dict) -> types.Part: """处理音频数据""" audio_base64 = input_audio_dict.get("data", "") audio_format = input_audio_dict.get("format", "") - + # 将 base64 字符串解码为二进制数据 audio_bytes = base64.b64decode(audio_base64) - + # 根据音频格式确定 MIME 类型 mime_type_map = { "wav": "audio/wav", @@ -214,11 +217,17 @@ def process_input_audio(input_audio_dict: dict) -> types.Part: "flac": "audio/flac", } mime_type = mime_type_map.get(audio_format, "audio/wav") - - logger.debug(f"处理 OpenAI 格式音频数据,格式: {audio_format}, MIME类型: {mime_type}") + + logger.debug( + f"处理 OpenAI 格式音频数据,格式: {audio_format}, MIME类型: {mime_type}" + ) return types.Part.from_bytes(data=audio_bytes, mime_type=mime_type) - def append_or_extend(contents: list[types.Content], part: list[types.Part], content_cls: type[types.Content]) -> None: + def append_or_extend( + contents: list[types.Content], + part: list[types.Part], + content_cls: type[types.Content], + ) -> None: if contents and isinstance(contents[-1], content_cls): contents[-1].parts.extend(part) else: @@ -252,7 +261,7 @@ def append_or_extend(contents: list[types.Content], part: list[types.Part], cont if content: parts = [types.Part.from_text(text=content)] append_or_extend(gemini_contents, parts, types.ModelContent) - elif not native_tool_enabled and "tool_calls" in message : + elif not native_tool_enabled and "tool_calls" in message: parts = [ types.Part.from_function_call( name=tool["function"]["name"], @@ -338,9 +347,7 @@ def _process_content_parts( chain.append(Comp.Image.fromBytes(part.inline_data.data)) return MessageChain(chain=chain) - async def _query( - self, payloads: dict, tools: FuncCall - ) -> LLMResponse: + async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: """非流式请求 Gemini API""" system_instruction = next( (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"), @@ -352,7 +359,7 @@ async def _query( modalities.append("Image") conversation = self._prepare_conversation(payloads) - temperature=payloads.get("temperature", 0.7) + temperature = payloads.get("temperature", 0.7) result: Optional[types.GenerateContentResponse] = None while True: @@ -473,11 +480,12 @@ async def text_chat( prompt: str, session_id: str = None, image_urls: List[str] = None, - audio_urls: List[str] = None, func_tool: FuncCall = None, contexts=[], system_prompt=None, tool_calls_result=None, + audio_urls: List[str] = None, + video_urls: List[str] = None, **kwargs, ) -> LLMResponse: new_record = await self.assemble_context(prompt, image_urls, audio_urls) @@ -514,11 +522,12 @@ async def text_chat_stream( prompt: str, session_id: str = None, image_urls: List[str] = None, - audio_urls: List[str] = None, func_tool: FuncCall = None, contexts=[], system_prompt=None, tool_calls_result=None, + audio_urls: List[str] = None, + video_urls: List[str] = None, **kwargs, ) -> AsyncGenerator[LLMResponse, None]: new_record = await self.assemble_context(prompt, image_urls, audio_urls) @@ -573,18 +582,22 @@ def set_key(self, key): self.chosen_api_key = key self._init_client() - async def assemble_context(self, text: str, image_urls: List[str] = None, audio_urls: List[str] = None): + async def assemble_context( + self, 
text: str, image_urls: List[str] = None, audio_urls: List[str] = None + ): """ 组装上下文。将用户输入(文本、图片和音频)组装成 OpenAI 格式的上下文数据。 """ - has_media = (image_urls and len(image_urls) > 0) or (audio_urls and len(audio_urls) > 0) - + has_media = (image_urls and len(image_urls) > 0) or ( + audio_urls and len(audio_urls) > 0 + ) + if has_media: user_content = { "role": "user", "content": [{"type": "text", "text": text if text else "[媒体内容]"}], } - + # 处理图片 if image_urls: for image_url in image_urls: @@ -601,23 +614,27 @@ async def assemble_context(self, text: str, image_urls: List[str] = None, audio_ continue user_content["content"].append( {"type": "image_url", "image_url": {"url": image_data}} - ) + ) # 处理音频 if audio_urls: for audio_url in audio_urls: try: - audio_base64, audio_format = await self.encode_audio_bs64(audio_url) + audio_base64, audio_format = await self.encode_audio_bs64( + audio_url + ) if audio_base64 and audio_format: - user_content["content"].append({ - "type": "input_audio", - "input_audio": { - "data": audio_base64, - "format": audio_format + user_content["content"].append( + { + "type": "input_audio", + "input_audio": { + "data": audio_base64, + "format": audio_format, + }, } - }) + ) except Exception as e: logger.error(f"音频文件处理失败: {audio_url}, 错误: {e}") - + return user_content else: return {"role": "user", "content": text} @@ -642,13 +659,15 @@ async def encode_audio_bs64(self, audio_url: str) -> tuple: with open(audio_url, "rb") as f: audio_bytes = f.read() audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") - + # 确定音频格式 extension = os.path.splitext(audio_url)[1].lower() # 移除扩展名前面的点号 - audio_format = extension[1:] if extension.startswith('.') else extension - - logger.info(f"音频文件转换成功: {audio_url},格式: {audio_format},大小: {len(audio_bytes)} 字节") + audio_format = extension[1:] if extension.startswith(".") else extension + + logger.info( + f"音频文件转换成功: {audio_url},格式: {audio_format},大小: {len(audio_bytes)} 字节" + ) return audio_base64, audio_format except Exception as e: logger.error(f"音频文件转换失败: {audio_url}, 错误: {e}") diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 5399fbc32..434e891a8 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -226,6 +226,8 @@ async def _prepare_chat_payload( contexts=[], system_prompt=None, tool_calls_result=None, + audio_urls: List[str] = None, + video_urls: List[str] = None, **kwargs, ) -> tuple: """准备聊天所需的有效载荷和上下文""" @@ -344,6 +346,8 @@ async def text_chat( contexts=[], system_prompt=None, tool_calls_result=None, + audio_urls: List[str] = None, + video_urls: List[str] = None, **kwargs, ) -> LLMResponse: payloads, context_query, func_tool = await self._prepare_chat_payload( @@ -354,6 +358,8 @@ async def text_chat( contexts, system_prompt, tool_calls_result, + audio_urls, + video_urls, **kwargs, ) @@ -413,6 +419,8 @@ async def text_chat_stream( contexts=[], system_prompt=None, tool_calls_result=None, + audio_urls: List[str] = None, + video_urls: List[str] = None, **kwargs, ) -> AsyncGenerator[LLMResponse, None]: """流式对话,与服务商交互并逐步返回结果""" diff --git a/astrbot/core/provider/sources/zhipu_source.py b/astrbot/core/provider/sources/zhipu_source.py index 2f7490317..2fb6791c7 100644 --- a/astrbot/core/provider/sources/zhipu_source.py +++ b/astrbot/core/provider/sources/zhipu_source.py @@ -1,13 +1,11 @@ from astrbot.core.db import BaseDatabase from astrbot import logger -from astrbot.core.provider.func_tool_manager import FuncCall 
-from typing import List from ..register import register_provider_adapter from astrbot.core.provider.entities import LLMResponse from .openai_source import ProviderOpenAIOfficial -@register_provider_adapter("zhipu_chat_completion", "智浦 Chat Completion 提供商适配器") +@register_provider_adapter("zhipu_chat_completion", "智谱 Chat Completion 提供商适配器") class ProviderZhipu(ProviderOpenAIOfficial): def __init__( self, @@ -27,12 +25,15 @@ def __init__( async def text_chat( self, - prompt: str, - session_id: str = None, - image_urls: List[str] = None, - func_tool: FuncCall = None, + prompt, + session_id=None, + image_urls=[], + func_tool=None, contexts=[], system_prompt=None, + tool_calls_result=None, + audio_urls=None, + video_urls=None, **kwargs, ) -> LLMResponse: new_record = await self.assemble_context(prompt, image_urls)
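
For reference, the audio round trip these patches implement can be exercised in
isolation. The sketch below is distilled from PATCH 3/4 rather than copied from
the integrated code, and it makes two assumptions: the google-genai package is
installed (for types.Part), and "/tmp/voice.wav" stands in for a real audio
file path.

    import base64
    import os

    from google.genai import types  # assumption: google-genai is installed

    def encode_audio_bs64(audio_path: str) -> tuple[str, str]:
        # Read the file and base64-encode it, mirroring encode_audio_bs64 in PATCH 3.
        with open(audio_path, "rb") as f:
            audio_bytes = f.read()
        audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
        # Derive the OpenAI-style "format" field from the extension (".wav" -> "wav").
        audio_format = os.path.splitext(audio_path)[1].lower().lstrip(".")
        return audio_base64, audio_format

    def to_gemini_part(input_audio: dict) -> types.Part:
        # Mirror process_input_audio: decode the base64 payload back to bytes
        # and map the declared format onto a MIME type, defaulting to audio/wav.
        mime_type_map = {
            "wav": "audio/wav",
            "mp3": "audio/mp3",
            "aiff": "audio/aiff",
            "aac": "audio/aac",
            "ogg": "audio/ogg",
            "flac": "audio/flac",
        }
        audio_bytes = base64.b64decode(input_audio.get("data", ""))
        mime_type = mime_type_map.get(input_audio.get("format", ""), "audio/wav")
        return types.Part.from_bytes(data=audio_bytes, mime_type=mime_type)

    # Usage: assemble the OpenAI-format content entry, then convert it for Gemini.
    data, fmt = encode_audio_bs64("/tmp/voice.wav")  # hypothetical path
    entry = {"type": "input_audio", "input_audio": {"data": data, "format": fmt}}
    part = to_gemini_part(entry["input_audio"])

The base64 detour is deliberate: the OpenAI-style content list stays
JSON-serializable while it sits in the conversation history, and the Gemini
adapter only decodes back to raw bytes via types.Part.from_bytes at
request-assembly time.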