diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py
index fd70275d8..fdab7c765 100644
--- a/astrbot/core/pipeline/process_stage/method/llm_request.py
+++ b/astrbot/core/pipeline/process_stage/method/llm_request.py
@@ -14,7 +14,7 @@
     ResultContentType,
     MessageChain,
 )
-from astrbot.core.message.components import Image
+from astrbot.core.message.components import Image, Record
 from astrbot.core import logger
 from astrbot.core.utils.metrics import Metric
 from astrbot.core.provider.entities import (
@@ -77,16 +77,29 @@ async def process(
                 )
             else:
-                req = ProviderRequest(prompt="", image_urls=[])
+                req = ProviderRequest(prompt="", image_urls=[], audio_urls=[])
                 if self.provider_wake_prefix:
                     if not event.message_str.startswith(self.provider_wake_prefix):
                         return
                 req.prompt = event.message_str[len(self.provider_wake_prefix) :]
                 req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
+
+                # Collect images and audio from the message
+                has_audio = False
                 for comp in event.message_obj.message:
                     if isinstance(comp, Image):
                         image_path = await comp.convert_to_file_path()
                         req.image_urls.append(image_path)
+                    elif isinstance(comp, Record):
+                        # Handle voice messages
+                        audio_path = await comp.convert_to_file_path()
+                        logger.info(f"Audio message detected, path: {audio_path}")
+                        has_audio = True
+                        req.audio_urls.append(audio_path)
+
+                # If the message carries only audio and no text, add a default prompt
+                if not req.prompt and has_audio:
+                    req.prompt = "[The user sent an audio message; treat it as text input and chat with them]"
 
             # Fetch the conversation context
             conversation_id = await self.conv_manager.get_curr_conversation_id(
diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py
index 6ad67da55..6bd28d01e 100644
--- a/astrbot/core/provider/entities.py
+++ b/astrbot/core/provider/entities.py
@@ -96,6 +96,8 @@ class ProviderRequest:
     """Session ID"""
     image_urls: List[str] = None
     """List of image URLs"""
+    audio_urls: List[str] = None
+    """List of audio URLs"""
     func_tool: FuncCall = None
     """Available function tools"""
     contexts: List = None
diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py
index 96547c5c2..e9b3d1cd1 100644
--- a/astrbot/core/provider/provider.py
+++ b/astrbot/core/provider/provider.py
@@ -91,6 +91,8 @@ async def text_chat(
         contexts: List = None,
         system_prompt: str = None,
         tool_calls_result: ToolCallsResult = None,
+        audio_urls: List[str] = None,
+        video_urls: List[str] = None,
         **kwargs,
     ) -> LLMResponse:
         """Get the LLM's text chat result. The current model is used for the conversation.
@@ -98,10 +100,12 @@
         Args:
             prompt: the prompt text
             session_id: session ID (this parameter is deprecated)
-            image_urls: list of image URLs
+            image_urls: list of image URLs; requires model support.
             tools: function-calling tools
             contexts: conversation context
             tool_calls_result: tool-call results passed back to the LLM. See: https://platform.openai.com/docs/guides/function-calling
+            audio_urls: list of audio URLs passed to the model; requires model support.
+            video_urls: list of video URLs passed to the model; requires model support.
             kwargs: other parameters
 
         Notes:
@@ -119,6 +123,8 @@ async def text_chat_stream(
         contexts: List = None,
         system_prompt: str = None,
         tool_calls_result: ToolCallsResult = None,
+        audio_urls: List[str] = None,
+        video_urls: List[str] = None,
         **kwargs,
     ) -> AsyncGenerator[LLMResponse, None]:
         """Get the LLM's streaming text chat result. The current model is used for the conversation. A complete result is returned once at the end of generation.
diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py
index 319515c52..4c43c473f 100644
--- a/astrbot/core/provider/sources/anthropic_source.py
+++ b/astrbot/core/provider/sources/anthropic_source.py
@@ -107,6 +107,8 @@ async def text_chat(
         contexts=[],
         system_prompt=None,
         tool_calls_result: ToolCallsResult = None,
+        audio_urls: List[str] = None,
+        video_urls: List[str] = None,
         **kwargs,
     ) -> LLMResponse:
         if not prompt:
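
Review note: with audio_urls threaded through ProviderRequest and Provider.text_chat, a caller can hand local audio files straight to any adapter that supports them. A minimal usage sketch, not code from this PR: `provider` stands for any audio-capable Provider implementation (such as the Gemini source further down), the file path is hypothetical, and reading `completion_text` off the LLMResponse is an assumption about the existing API.

    # Hedged usage sketch; `provider` and the path are illustrative.
    response = await provider.text_chat(
        prompt="Please respond to this voice message.",
        audio_urls=["/tmp/records/voice_1234.wav"],  # local file paths, as collected above
        contexts=[],
        system_prompt="You are a helpful assistant.",
    )
    print(response.completion_text)  # assumed LLMResponse field
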
diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py
index 2c4930692..81b052588 100644
--- a/astrbot/core/provider/sources/dashscope_source.py
+++ b/astrbot/core/provider/sources/dashscope_source.py
@@ -72,6 +72,9 @@ async def text_chat(
         func_tool: FuncCall = None,
         contexts: List = None,
         system_prompt: str = None,
+        tool_calls_result=None,
+        audio_urls=None,
+        video_urls=None,
         **kwargs,
     ) -> LLMResponse:
         # Get session variables
diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py
index 78e3760c1..5234ffc85 100644
--- a/astrbot/core/provider/sources/dify_source.py
+++ b/astrbot/core/provider/sources/dify_source.py
@@ -64,6 +64,9 @@ async def text_chat(
         func_tool: FuncCall = None,
         contexts: List = None,
         system_prompt: str = None,
+        tool_calls_result=None,
+        audio_urls=None,
+        video_urls=None,
         **kwargs,
     ) -> LLMResponse:
         result = ""
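
Review note: dashscope and dify accept the new keyword parameters only to ignore them. That is deliberate: the pipeline calls every adapter through the one Provider signature, so an adapter lacking the parameters could raise as soon as a request carries audio. A self-contained illustration of the failure this avoids:

    import asyncio

    async def text_chat_old(prompt, image_urls=None):  # pre-PR style signature
        return f"echo: {prompt}"

    async def main():
        try:
            await text_chat_old("hi", audio_urls=["voice.wav"])
        except TypeError as e:
            print(e)  # ... got an unexpected keyword argument 'audio_urls'

    asyncio.run(main())
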
diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py
index a175a3d68..eb6ccb989 100644
--- a/astrbot/core/provider/sources/gemini_source.py
+++ b/astrbot/core/provider/sources/gemini_source.py
@@ -3,6 +3,7 @@
 import json
 import logging
 import random
+import os
 from typing import Dict, List, Optional
 from collections.abc import AsyncGenerator
 
@@ -162,13 +163,17 @@ async def _prepare_query_config(
         return types.GenerateContentConfig(
             system_instruction=system_instruction,
             temperature=temperature,
-            max_output_tokens=payloads.get("max_tokens") or payloads.get("maxOutputTokens"),
+            max_output_tokens=payloads.get("max_tokens")
+            or payloads.get("maxOutputTokens"),
             top_p=payloads.get("top_p") or payloads.get("topP"),
             top_k=payloads.get("top_k") or payloads.get("topK"),
-            frequency_penalty=payloads.get("frequency_penalty") or payloads.get("frequencyPenalty"),
-            presence_penalty=payloads.get("presence_penalty") or payloads.get("presencePenalty"),
+            frequency_penalty=payloads.get("frequency_penalty")
+            or payloads.get("frequencyPenalty"),
+            presence_penalty=payloads.get("presence_penalty")
+            or payloads.get("presencePenalty"),
             stop_sequences=payloads.get("stop") or payloads.get("stopSequences"),
-            response_logprobs=payloads.get("response_logprobs") or payloads.get("responseLogprobs"),
+            response_logprobs=payloads.get("response_logprobs")
+            or payloads.get("responseLogprobs"),
             logprobs=payloads.get("logprobs"),
             seed=payloads.get("seed"),
             response_modalities=modalities,
@@ -194,7 +199,35 @@ def process_image_url(image_url_dict: dict) -> types.Part:
                 image_bytes = base64.b64decode(url.split(",", 1)[1])
                 return types.Part.from_bytes(data=image_bytes, mime_type=mime_type)
 
-        def append_or_extend(contents: list[types.Content], part: list[types.Part], content_cls: type[types.Content]) -> None:
+        def process_input_audio(input_audio_dict: dict) -> types.Part:
+            """Process audio data"""
+            audio_base64 = input_audio_dict.get("data", "")
+            audio_format = input_audio_dict.get("format", "")
+
+            # Decode the base64 string into binary data
+            audio_bytes = base64.b64decode(audio_base64)
+
+            # Pick the MIME type based on the audio format
+            mime_type_map = {
+                "wav": "audio/wav",
+                "mp3": "audio/mp3",
+                "aiff": "audio/aiff",
+                "aac": "audio/aac",
+                "ogg": "audio/ogg",
+                "flac": "audio/flac",
+            }
+            mime_type = mime_type_map.get(audio_format, "audio/wav")
+
+            logger.debug(
+                f"Processing OpenAI-format audio data, format: {audio_format}, MIME type: {mime_type}"
+            )
+            return types.Part.from_bytes(data=audio_bytes, mime_type=mime_type)
+
+        def append_or_extend(
+            contents: list[types.Content],
+            part: list[types.Part],
+            content_cls: type[types.Content],
+        ) -> None:
             if contents and isinstance(contents[-1], content_cls):
                 contents[-1].parts.extend(part)
             else:
@@ -212,12 +245,14 @@ def append_or_extend(contents: list[types.Content], part: list[types.Part], cont
             if role == "user":
                 if isinstance(content, list):
-                    parts = [
-                        types.Part.from_text(text=item["text"] or " ")
-                        if item["type"] == "text"
-                        else process_image_url(item["image_url"])
-                        for item in content
-                    ]
+                    parts = []
+                    for item in content:
+                        if item["type"] == "text":
+                            parts.append(types.Part.from_text(text=item["text"] or " "))
+                        elif item["type"] == "image_url":
+                            parts.append(process_image_url(item["image_url"]))
+                        elif item["type"] == "input_audio":
+                            parts.append(process_input_audio(item["input_audio"]))
                 else:
                     parts = [create_text_part(content)]
                 append_or_extend(gemini_contents, parts, types.UserContent)
@@ -226,7 +261,7 @@ def append_or_extend(contents: list[types.Content], part: list[types.Part], cont
                 if content:
                     parts = [types.Part.from_text(text=content)]
                     append_or_extend(gemini_contents, parts, types.ModelContent)
-                elif not native_tool_enabled and "tool_calls" in message :
+                elif not native_tool_enabled and "tool_calls" in message:
                     parts = [
                         types.Part.from_function_call(
                             name=tool["function"]["name"],
@@ -312,9 +347,7 @@ def _process_content_parts(
                     chain.append(Comp.Image.fromBytes(part.inline_data.data))
         return MessageChain(chain=chain)
 
-    async def _query(
-        self, payloads: dict, tools: FuncCall
-    ) -> LLMResponse:
+    async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
         """Non-streaming request to the Gemini API"""
         system_instruction = next(
             (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
@@ -326,7 +359,7 @@ async def _query(
             modalities.append("Image")
 
         conversation = self._prepare_conversation(payloads)
-        temperature=payloads.get("temperature", 0.7)
+        temperature = payloads.get("temperature", 0.7)
 
         result: Optional[types.GenerateContentResponse] = None
         while True:
@@ -451,9 +484,11 @@ async def text_chat(
         contexts=[],
         system_prompt=None,
         tool_calls_result=None,
+        audio_urls: List[str] = None,
+        video_urls: List[str] = None,
         **kwargs,
     ) -> LLMResponse:
-        new_record = await self.assemble_context(prompt, image_urls)
+        new_record = await self.assemble_context(prompt, image_urls, audio_urls)
         context_query = [*contexts, new_record]
         if system_prompt:
             context_query.insert(0, {"role": "system", "content": system_prompt})
@@ -486,14 +521,16 @@ async def text_chat_stream(
         self,
         prompt: str,
         session_id: str = None,
-        image_urls: List[str] = [],
+        image_urls: List[str] = None,
         func_tool: FuncCall = None,
         contexts=[],
         system_prompt=None,
         tool_calls_result=None,
+        audio_urls: List[str] = None,
+        video_urls: List[str] = None,
         **kwargs,
     ) -> AsyncGenerator[LLMResponse, None]:
-        new_record = await self.assemble_context(prompt, image_urls)
+        new_record = await self.assemble_context(prompt, image_urls, audio_urls)
         context_query = [*contexts, new_record]
         if system_prompt:
             context_query.insert(0, {"role": "system", "content": system_prompt})
@@ -545,30 +582,59 @@ def set_key(self, key):
         self.chosen_api_key = key
         self._init_client()
 
-    async def assemble_context(self, text: str, image_urls: List[str] = None):
+    async def assemble_context(
+        self, text: str, image_urls: List[str] = None, audio_urls: List[str] = None
+    ):
         """
-        Assemble the context.
+        Assemble the context. Packs the user input (text, images, and audio) into OpenAI-format context data.
         """
-        if image_urls:
+        has_media = (image_urls and len(image_urls) > 0) or (
+            audio_urls and len(audio_urls) > 0
+        )
+
+        if has_media:
             user_content = {
                 "role": "user",
-                "content": [{"type": "text", "text": text if text else "[image]"}],
+                "content": [{"type": "text", "text": text if text else "[media content]"}],
             }
-            for image_url in image_urls:
-                if image_url.startswith("http"):
-                    image_path = await download_image_by_url(image_url)
-                    image_data = await self.encode_image_bs64(image_path)
-                elif image_url.startswith("file:///"):
-                    image_path = image_url.replace("file:///", "")
-                    image_data = await self.encode_image_bs64(image_path)
-                else:
-                    image_data = await self.encode_image_bs64(image_url)
-                if not image_data:
-                    logger.warning(f"Image {image_url} produced empty data and will be ignored.")
-                    continue
-                user_content["content"].append(
-                    {"type": "image_url", "image_url": {"url": image_data}}
-                )
+
+            # Handle images
+            if image_urls:
+                for image_url in image_urls:
+                    if image_url.startswith("http"):
+                        image_path = await download_image_by_url(image_url)
+                        image_data = await self.encode_image_bs64(image_path)
+                    elif image_url.startswith("file:///"):
+                        image_path = image_url.replace("file:///", "")
+                        image_data = await self.encode_image_bs64(image_path)
+                    else:
+                        image_data = await self.encode_image_bs64(image_url)
+                    if not image_data:
+                        logger.warning(f"Image {image_url} produced empty data and will be ignored.")
+                        continue
+                    user_content["content"].append(
+                        {"type": "image_url", "image_url": {"url": image_data}}
+                    )
+            # Handle audio
+            if audio_urls:
+                for audio_url in audio_urls:
+                    try:
+                        audio_base64, audio_format = await self.encode_audio_bs64(
+                            audio_url
+                        )
+                        if audio_base64 and audio_format:
+                            user_content["content"].append(
+                                {
+                                    "type": "input_audio",
+                                    "input_audio": {
+                                        "data": audio_base64,
+                                        "format": audio_format,
+                                    },
+                                }
+                            )
+                    except Exception as e:
+                        logger.error(f"Failed to process audio file: {audio_url}, error: {e}")
+
             return user_content
         else:
             return {"role": "user", "content": text}
@@ -584,5 +650,28 @@ async def encode_image_bs64(self, image_url: str) -> str:
             return "data:image/jpeg;base64," + image_bs64
         return ""
 
+    async def encode_audio_bs64(self, audio_url: str) -> tuple:
+        """
+        Encode an audio file as base64.
+        """
+        try:
+            # Read the audio file and encode it as base64
+            with open(audio_url, "rb") as f:
+                audio_bytes = f.read()
+            audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
+
+            # Determine the audio format
+            extension = os.path.splitext(audio_url)[1].lower()
+            # Strip the leading dot from the extension
+            audio_format = extension[1:] if extension.startswith(".") else extension
+
+            logger.info(
+                f"Audio file encoded successfully: {audio_url}, format: {audio_format}, size: {len(audio_bytes)} bytes"
+            )
+            return audio_base64, audio_format
+        except Exception as e:
+            logger.error(f"Failed to encode audio file: {audio_url}, error: {e}")
+            return None, None
+
     async def terminate(self):
         logger.info("Google GenAI adapter terminated.")
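
Review note: the Gemini adapter keeps the OpenAI content-part schema as its internal format: assemble_context packs each audio file into an input_audio part, and process_input_audio later unpacks it into a types.Part for the Gemini request. A hedged round-trip sketch under those assumptions (the sample.wav path is hypothetical; the dict shape mirrors the hunks above):

    import base64
    import os

    # roughly what encode_audio_bs64 produces
    with open("sample.wav", "rb") as f:  # hypothetical local file
        audio_b64 = base64.b64encode(f.read()).decode("utf-8")
    audio_format = os.path.splitext("sample.wav")[1].lstrip(".")  # -> "wav"

    # what assemble_context appends to user_content["content"]
    part = {
        "type": "input_audio",
        "input_audio": {"data": audio_b64, "format": audio_format},
    }

    # process_input_audio(part["input_audio"]) then base64-decodes the data
    # and maps "wav" -> "audio/wav" before calling types.Part.from_bytes(...)
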
"user", - "content": [{"type": "text", "text": text if text else "[图片]"}], + "content": [{"type": "text", "text": text if text else "[媒体内容]"}], } - for image_url in image_urls: - if image_url.startswith("http"): - image_path = await download_image_by_url(image_url) - image_data = await self.encode_image_bs64(image_path) - elif image_url.startswith("file:///"): - image_path = image_url.replace("file:///", "") - image_data = await self.encode_image_bs64(image_path) - else: - image_data = await self.encode_image_bs64(image_url) - if not image_data: - logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") - continue - user_content["content"].append( - {"type": "image_url", "image_url": {"url": image_data}} - ) + + # 处理图片 + if image_urls: + for image_url in image_urls: + if image_url.startswith("http"): + image_path = await download_image_by_url(image_url) + image_data = await self.encode_image_bs64(image_path) + elif image_url.startswith("file:///"): + image_path = image_url.replace("file:///", "") + image_data = await self.encode_image_bs64(image_path) + else: + image_data = await self.encode_image_bs64(image_url) + if not image_data: + logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") + continue + user_content["content"].append( + {"type": "image_url", "image_url": {"url": image_data}} + ) + # 处理音频 + if audio_urls: + for audio_url in audio_urls: + try: + audio_base64, audio_format = await self.encode_audio_bs64( + audio_url + ) + if audio_base64 and audio_format: + user_content["content"].append( + { + "type": "input_audio", + "input_audio": { + "data": audio_base64, + "format": audio_format, + }, + } + ) + except Exception as e: + logger.error(f"音频文件处理失败: {audio_url}, 错误: {e}") + return user_content else: return {"role": "user", "content": text} @@ -584,5 +650,28 @@ async def encode_image_bs64(self, image_url: str) -> str: return "data:image/jpeg;base64," + image_bs64 return "" + async def encode_audio_bs64(self, audio_url: str) -> tuple: + """ + 将音频文件转换为 base64 编码 + """ + try: + # 读取音频文件并编码为 base64 + with open(audio_url, "rb") as f: + audio_bytes = f.read() + audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") + + # 确定音频格式 + extension = os.path.splitext(audio_url)[1].lower() + # 移除扩展名前面的点号 + audio_format = extension[1:] if extension.startswith(".") else extension + + logger.info( + f"音频文件转换成功: {audio_url},格式: {audio_format},大小: {len(audio_bytes)} 字节" + ) + return audio_base64, audio_format + except Exception as e: + logger.error(f"音频文件转换失败: {audio_url}, 错误: {e}") + return None, None + async def terminate(self): logger.info("Google GenAI 适配器已终止。") diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 5399fbc32..434e891a8 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -226,6 +226,8 @@ async def _prepare_chat_payload( contexts=[], system_prompt=None, tool_calls_result=None, + audio_urls: List[str] = None, + video_urls: List[str] = None, **kwargs, ) -> tuple: """准备聊天所需的有效载荷和上下文""" @@ -344,6 +346,8 @@ async def text_chat( contexts=[], system_prompt=None, tool_calls_result=None, + audio_urls: List[str] = None, + video_urls: List[str] = None, **kwargs, ) -> LLMResponse: payloads, context_query, func_tool = await self._prepare_chat_payload( @@ -354,6 +358,8 @@ async def text_chat( contexts, system_prompt, tool_calls_result, + audio_urls, + video_urls, **kwargs, ) @@ -413,6 +419,8 @@ async def text_chat_stream( contexts=[], system_prompt=None, tool_calls_result=None, + 
+        audio_urls: List[str] = None,
+        video_urls: List[str] = None,
         **kwargs,
     ) -> AsyncGenerator[LLMResponse, None]:
         """Streaming chat: interact with the provider and return results incrementally"""
diff --git a/astrbot/core/provider/sources/zhipu_source.py b/astrbot/core/provider/sources/zhipu_source.py
index 2f7490317..2fb6791c7 100644
--- a/astrbot/core/provider/sources/zhipu_source.py
+++ b/astrbot/core/provider/sources/zhipu_source.py
@@ -1,13 +1,11 @@
 from astrbot.core.db import BaseDatabase
 from astrbot import logger
-from astrbot.core.provider.func_tool_manager import FuncCall
-from typing import List
 from ..register import register_provider_adapter
 from astrbot.core.provider.entities import LLMResponse
 from .openai_source import ProviderOpenAIOfficial
 
 
-@register_provider_adapter("zhipu_chat_completion", "智浦 Chat Completion 提供商适配器")
+@register_provider_adapter("zhipu_chat_completion", "智谱 Chat Completion 提供商适配器")
 class ProviderZhipu(ProviderOpenAIOfficial):
     def __init__(
         self,
@@ -27,12 +25,15 @@ def __init__(
 
     async def text_chat(
         self,
-        prompt: str,
-        session_id: str = None,
-        image_urls: List[str] = None,
-        func_tool: FuncCall = None,
+        prompt,
+        session_id=None,
+        image_urls=[],
+        func_tool=None,
         contexts=[],
         system_prompt=None,
+        tool_calls_result=None,
+        audio_urls=None,
+        video_urls=None,
         **kwargs,
     ) -> LLMResponse:
         new_record = await self.assemble_context(prompt, image_urls)
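
Review note: the reworked zhipu signature drops its type hints and changes the image_urls default from None to a mutable []. As far as this hunk shows, text_chat never appends to the list, so this is a latent hazard rather than a live bug, but Python evaluates default values once, at definition time, as this standalone snippet shows:

    def f(items=[]):  # one shared list, created when the def runs
        items.append(1)
        return items

    print(f())  # [1]
    print(f())  # [1, 1]  <- state leaked from the previous call
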