Skip to content

Commit 625bc9a

Browse files
authored
Merge branch 'jxxghp:v2' into v2
2 parents c4618b4 + 5e077cd commit 625bc9a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+2667
-411
lines changed

app/agent/__init__.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
88
from langchain_community.callbacks import get_openai_callback
99
from langchain_core.chat_history import InMemoryChatMessageHistory
10-
from langchain_core.messages import HumanMessage, AIMessage, ToolCall
10+
from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage, SystemMessage
1111
from langchain_core.runnables.history import RunnableWithMessageHistory
1212

1313
from app.agent.callback import StreamingCallbackHandler
@@ -56,9 +56,6 @@ def __init__(self, session_id: str, user_id: str = None,
5656
# 工具
5757
self.tools = self._initialize_tools()
5858

59-
# 会话存储
60-
self.session_store = self._initialize_session_store()
61-
6259
# 提示词模板
6360
self.prompt = self._initialize_prompt()
6461

@@ -127,7 +124,8 @@ def _initialize_tools(self) -> List:
127124
channel=self.channel,
128125
source=self.source,
129126
username=self.username,
130-
callback_handler=self.callback_handler
127+
callback_handler=self.callback_handler,
128+
memory_mananger=self.memory_manager
131129
)
132130

133131
@staticmethod
@@ -137,34 +135,36 @@ def _initialize_session_store() -> Dict[str, InMemoryChatMessageHistory]:
137135

138136
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
139137
"""获取会话历史"""
140-
if session_id not in self.session_store:
141-
chat_history = InMemoryChatMessageHistory()
142-
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
143-
session_id=session_id,
144-
user_id=self.user_id
145-
)
146-
if messages:
147-
for msg in messages:
148-
if msg.get("role") == "user":
149-
chat_history.add_user_message(HumanMessage(content=msg.get("content", "")))
150-
elif msg.get("role") == "agent":
151-
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
152-
elif msg.get("role") == "tool_call":
153-
metadata = msg.get("metadata", {})
154-
chat_history.add_ai_message(AIMessage(
138+
chat_history = InMemoryChatMessageHistory()
139+
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
140+
session_id=session_id,
141+
user_id=self.user_id
142+
)
143+
if messages:
144+
for msg in messages:
145+
if msg.get("role") == "user":
146+
chat_history.add_message(HumanMessage(content=msg.get("content", "")))
147+
elif msg.get("role") == "agent":
148+
chat_history.add_message(AIMessage(content=msg.get("content", "")))
149+
elif msg.get("role") == "tool_call":
150+
metadata = msg.get("metadata", {})
151+
chat_history.add_message(
152+
AIMessage(
155153
content=msg.get("content", ""),
156-
tool_calls=[ToolCall(
157-
id=metadata.get("call_id"),
158-
name=metadata.get("tool_name"),
159-
args=metadata.get("parameters"),
160-
)]
161-
))
162-
elif msg.get("role") == "tool_result":
163-
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
164-
elif msg.get("role") == "system":
165-
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
166-
self.session_store[session_id] = chat_history
167-
return self.session_store[session_id]
154+
tool_calls=[
155+
ToolCall(
156+
id=metadata.get("call_id"),
157+
name=metadata.get("tool_name"),
158+
args=metadata.get("parameters"),
159+
)
160+
]
161+
)
162+
)
163+
elif msg.get("role") == "tool_result":
164+
chat_history.add_message(ToolMessage(content=msg.get("content", "")))
165+
elif msg.get("role") == "system":
166+
chat_history.add_message(SystemMessage(content=msg.get("content", "")))
167+
return chat_history
168168

169169
@staticmethod
170170
def _initialize_prompt() -> ChatPromptTemplate:
@@ -306,8 +306,6 @@ async def send_agent_message(self, message: str, title: str = "MoviePilot助手"
306306

307307
async def cleanup(self):
308308
"""清理智能体资源"""
309-
if self.session_id in self.session_store:
310-
del self.session_store[self.session_id]
311309
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")
312310

313311

app/agent/memory/__init__.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,27 @@ async def close(self):
4545

4646
logger.info("对话记忆管理器已关闭")
4747

48+
@staticmethod
49+
def get_memory_key(session_id: str, user_id: str):
50+
"""计算内存Key"""
51+
return f"{user_id}:{session_id}" if user_id else session_id
52+
53+
@staticmethod
54+
def get_redis_key(session_id: str, user_id: str):
55+
"""计算Redis Key"""
56+
return f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
57+
4858
async def get_memory(self, session_id: str, user_id: str) -> ConversationMemory:
4959
"""获取会话记忆"""
5060
# 首先检查缓存
51-
cache_key = f"{user_id}:{session_id}" if user_id else session_id
61+
cache_key = self.get_memory_key(session_id, user_id)
5262
if cache_key in self.memory_cache:
5363
return self.memory_cache[cache_key]
5464

5565
# 尝试从Redis加载
5666
if settings.CACHE_BACKEND_TYPE == "redis":
5767
try:
58-
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
68+
redis_key = self.get_redis_key(session_id, user_id)
5969
memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
6070
if memory_data:
6171
memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
@@ -180,15 +190,13 @@ def get_recent_messages_for_agent(
180190
181191
如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪
182192
"""
183-
cache_key = f"{user_id}:{session_id}" if user_id else session_id
193+
cache_key = self.get_memory_key(session_id, user_id)
184194
memory = self.memory_cache.get(cache_key)
185195
if not memory:
186196
return []
187197

188198
# 获取所有消息
189-
messages = memory.messages
190-
191-
return messages
199+
return memory.messages
192200

193201
async def get_recent_messages(
194202
self,
@@ -218,7 +226,7 @@ async def clear_memory(self, session_id: str, user_id: str):
218226
del self.memory_cache[cache_key]
219227

220228
if settings.CACHE_BACKEND_TYPE == "redis":
221-
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
229+
redis_key = self.get_redis_key(session_id, user_id)
222230
await self.redis_helper.delete(redis_key, region="AI_AGENT")
223231

224232
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
@@ -229,14 +237,14 @@ async def _save_memory(self, memory: ConversationMemory):
229237
Redis中的记忆会自动通过TTL机制过期,无需手动清理
230238
"""
231239
# 更新内存缓存
232-
cache_key = f"{memory.user_id}:{memory.session_id}" if memory.user_id else memory.session_id
240+
cache_key = self.get_memory_key(memory.session_id, memory.user_id)
233241
self.memory_cache[cache_key] = memory
234242

235243
# 保存到Redis,设置TTL自动过期
236244
if settings.CACHE_BACKEND_TYPE == "redis":
237245
try:
238246
memory_dict = memory.model_dump()
239-
redis_key = f"agent_memory:{memory.user_id}:{memory.session_id}" if memory.user_id else f"agent_memory:{memory.session_id}"
247+
redis_key = self.get_redis_key(memory.session_id, memory.user_id)
240248
ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
241249
await self.redis_helper.set(
242250
redis_key,

app/agent/tools/base.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""MoviePilot工具基类"""
2+
import json
23
from abc import ABCMeta, abstractmethod
3-
from typing import Callable, Any, Optional
4+
from typing import Any, Optional
45

56
from langchain.tools import BaseTool
67
from pydantic import PrivateAttr
78

8-
from app.agent import StreamingCallbackHandler
9+
from app.agent import StreamingCallbackHandler, ConversationMemoryManager
910
from app.chain import ChainBase
1011
from app.log import logger
1112
from app.schemas import Notification
@@ -24,6 +25,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
2425
_source: str = PrivateAttr(default=None)
2526
_username: str = PrivateAttr(default=None)
2627
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
28+
_memory_manager: ConversationMemoryManager = PrivateAttr(default=None)
2729

2830
def __init__(self, session_id: str, user_id: str, **kwargs):
2931
super().__init__(**kwargs)
@@ -35,24 +37,53 @@ def _run(self, *args: Any, **kwargs: Any) -> Any:
3537

3638
async def _arun(self, **kwargs) -> str:
3739
"""异步运行工具"""
38-
# 发送运行工具前的消息
40+
# 发送和记忆工具调用前的信息
3941
agent_message = await self._callback_handler.get_message()
4042
if agent_message:
43+
# 发送消息
4144
await self.send_tool_message(agent_message, title="MoviePilot助手")
42-
# 发送执行工具说明
43-
# 优先使用工具自定义的提示消息,如果没有则使用 explanation
45+
46+
# 记忆工具调用
47+
await self._memory_manager.add_memory(
48+
session_id=self._session_id,
49+
user_id=self._user_id,
50+
role="tool_call",
51+
content=agent_message,
52+
metadata={
53+
"call_id": self.__class__.__name__,
54+
"tool_name": self.__class__.__name__,
55+
"parameters": kwargs
56+
}
57+
)
58+
59+
# 发送执行工具说明,优先使用工具自定义的提示消息,如果没有则使用 explanation
4460
tool_message = self.get_tool_message(**kwargs)
4561
if not tool_message:
4662
explanation = kwargs.get("explanation")
4763
if explanation:
4864
tool_message = explanation
49-
5065
if tool_message:
5166
formatted_message = f"⚙️ => {tool_message}"
5267
await self.send_tool_message(formatted_message)
68+
5369
logger.debug(f'Executing tool {self.name} with args: {kwargs}')
5470
result = await self.run(**kwargs)
5571
logger.debug(f'Tool {self.name} executed with result: {result}')
72+
73+
# 记忆工具调用结果
74+
if isinstance(result, str):
75+
formated_result = result
76+
elif isinstance(result, int, float):
77+
formated_result = str(result)
78+
else:
79+
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
80+
await self._memory_manager.add_memory(
81+
session_id=self._session_id,
82+
user_id=self._user_id,
83+
role="tool_result",
84+
content=formated_result
85+
)
86+
5687
return result
5788

5889
def get_tool_message(self, **kwargs) -> Optional[str]:
@@ -84,6 +115,10 @@ def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
84115
"""设置回调处理器"""
85116
self._callback_handler = callback_handler
86117

118+
def set_memory_manager(self, memory_manager: ConversationMemoryManager):
119+
"""设置记忆客理器"""
120+
self._memory_manager = memory_manager
121+
87122
async def send_tool_message(self, message: str, title: str = ""):
88123
"""发送工具消息"""
89124
await ToolChain().async_post_message(

app/agent/tools/factory.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class MoviePilotToolFactory:
5151
@staticmethod
5252
def create_tools(session_id: str, user_id: str,
5353
channel: str = None, source: str = None, username: str = None,
54-
callback_handler: Callable = None) -> List[MoviePilotTool]:
54+
callback_handler: Callable = None, memory_mananger: Callable = None) -> List[MoviePilotTool]:
5555
"""创建MoviePilot工具列表"""
5656
tools = []
5757
tool_definitions = [
@@ -102,6 +102,7 @@ def create_tools(session_id: str, user_id: str,
102102
)
103103
tool.set_message_attr(channel=channel, source=source, username=username)
104104
tool.set_callback_handler(callback_handler=callback_handler)
105+
tool.set_memory_manager(memory_manager=memory_mananger)
105106
tools.append(tool)
106107

107108
# 加载插件提供的工具
@@ -124,6 +125,7 @@ def create_tools(session_id: str, user_id: str,
124125
)
125126
tool.set_message_attr(channel=channel, source=source, username=username)
126127
tool.set_callback_handler(callback_handler=callback_handler)
128+
tool.set_memory_manager(memory_manager=memory_mananger)
127129
tools.append(tool)
128130
plugin_tools_count += 1
129131
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")

app/agent/tools/impl/query_download_tasks.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
"""查询下载工具"""
22

33
import json
4-
from typing import Optional, Type
4+
from typing import Optional, Type, List, Union
55

66
from pydantic import BaseModel, Field
77

88
from app.agent.tools.base import MoviePilotTool
99
from app.chain.download import DownloadChain
1010
from app.db.downloadhistory_oper import DownloadHistoryOper
1111
from app.log import logger
12+
from app.schemas import TransferTorrent, DownloadingTorrent
13+
from app.schemas.types import TorrentStatus
1214

1315

1416
class QueryDownloadTasksInput(BaseModel):
@@ -27,6 +29,27 @@ class QueryDownloadTasksTool(MoviePilotTool):
2729
description: str = "Query download status and list download tasks. Can query all active downloads, or search for specific tasks by hash or title. Shows download progress, completion status, and task details from configured downloaders."
2830
args_schema: Type[BaseModel] = QueryDownloadTasksInput
2931

32+
def _get_all_torrents(self, download_chain: DownloadChain, downloader: Optional[str] = None) -> List[Union[TransferTorrent, DownloadingTorrent]]:
33+
"""
34+
查询所有状态的任务(包括下载中和已完成的任务)
35+
"""
36+
all_torrents = []
37+
# 查询正在下载的任务
38+
downloading_torrents = download_chain.list_torrents(
39+
downloader=downloader,
40+
status=TorrentStatus.DOWNLOADING
41+
) or []
42+
all_torrents.extend(downloading_torrents)
43+
44+
# 查询已完成的任务(可转移状态)
45+
transfer_torrents = download_chain.list_torrents(
46+
downloader=downloader,
47+
status=TorrentStatus.TRANSFER
48+
) or []
49+
all_torrents.extend(transfer_torrents)
50+
51+
return all_torrents
52+
3053
def get_tool_message(self, **kwargs) -> Optional[str]:
3154
"""根据查询参数生成友好的提示消息"""
3255
downloader = kwargs.get("downloader")
@@ -60,7 +83,7 @@ async def run(self, downloader: Optional[str] = None,
6083

6184
# 如果提供了hash,直接查询该hash的任务(不限制状态)
6285
if hash:
63-
torrents = download_chain.list_torrents(downloader=downloader, hashs=[hash])
86+
torrents = download_chain.list_torrents(downloader=downloader, hashs=[hash]) or []
6487
if not torrents:
6588
return f"未找到hash为 {hash} 的下载任务(该任务可能已完成、已删除或不存在)"
6689
# 转换为DownloadingTorrent格式
@@ -84,14 +107,25 @@ async def run(self, downloader: Optional[str] = None,
84107
elif title:
85108
# 如果提供了title,查询所有任务并搜索匹配的标题
86109
# 查询所有状态的任务
87-
all_torrents = download_chain.list_torrents(downloader=downloader) or []
110+
all_torrents = self._get_all_torrents(download_chain, downloader)
88111
filtered_downloads = []
112+
title_lower = title.lower()
89113
for torrent in all_torrents:
90-
# 检查标题或名称是否匹配
91-
if (title.lower() in (torrent.title or "").lower()) or \
92-
(title.lower() in (torrent.name or "").lower()):
93-
# 获取下载历史信息
94-
history = DownloadHistoryOper().get_by_hash(torrent.hash)
114+
# 获取下载历史信息
115+
history = DownloadHistoryOper().get_by_hash(torrent.hash)
116+
117+
# 检查标题或名称是否匹配(包括下载历史中的标题)
118+
matched = False
119+
# 检查torrent的title和name字段
120+
if (title_lower in (torrent.title or "").lower()) or \
121+
(title_lower in (torrent.name or "").lower()):
122+
matched = True
123+
# 检查下载历史中的标题
124+
if history and history.title:
125+
if title_lower in history.title.lower():
126+
matched = True
127+
128+
if matched:
95129
if history:
96130
torrent.media = {
97131
"tmdbid": history.tmdbid,
@@ -110,7 +144,7 @@ async def run(self, downloader: Optional[str] = None,
110144
# 根据status决定查询方式
111145
if status == "downloading":
112146
# 如果status为下载中,使用downloading方法
113-
downloads = download_chain.downloading(name=downloader)
147+
downloads = download_chain.downloading(name=downloader) or []
114148
filtered_downloads = []
115149
for dl in downloads:
116150
if downloader and dl.downloader != downloader:
@@ -119,7 +153,7 @@ async def run(self, downloader: Optional[str] = None,
119153
else:
120154
# 其他状态(completed、paused、all),使用list_torrents查询所有任务
121155
# 查询所有状态的任务
122-
all_torrents = download_chain.list_torrents(downloader=downloader) or []
156+
all_torrents = self._get_all_torrents(download_chain, downloader)
123157
filtered_downloads = []
124158
for torrent in all_torrents:
125159
if downloader and torrent.downloader != downloader:

0 commit comments

Comments
 (0)