77 CODER_PROMPT ,
88 MODELER_PROMPT ,
99)
10- from app .core .functions import tools
10+ from app .core .functions import coder_tools , writer_tools
1111from app .models .model import CoderToWriter
1212from app .models .user_output import UserOutput
1313from app .utils .enums import CompTemplate , FormatOutPut
1717from app .utils .redis_manager import redis_manager
1818from app .schemas .response import SystemMessage
1919from app .tools .base_interpreter import BaseCodeInterpreter
20+ from app .tools .openalex_scholar import OpenAlexScholar
2021
2122
2223class Agent :
@@ -26,7 +27,7 @@ def __init__(
2627 model : LLM ,
2728 max_chat_turns : int = 30 , # 单个agent最大对话轮次
2829 user_output : UserOutput = None ,
29- max_memory : int = 20 , # 最大记忆轮次
30+ max_memory : int = 25 , # 最大记忆轮次
3031 ) -> None :
3132 self .task_id = task_id
3233 self .model = model
@@ -85,7 +86,7 @@ def clear_memory(self):
8586 self .chat_history = self .chat_history [:2 ] + self .chat_history [- 5 :]
8687
8788
88- class ModelerAgent (Agent ): # 继承自Agent类而不是BaseModel
89+ class ModelerAgent (Agent ): # 继承自Agent类
8990 def __init__ (
9091 self ,
9192 model : LLM ,
@@ -168,7 +169,7 @@ async def run(self, prompt: str, subtask_title: str) -> CoderToWriter:
168169 logger .info (f"当前对话轮次: { self .current_chat_turns } " )
169170 response = await self .model .chat (
170171 history = self .chat_history ,
171- tools = tools ,
172+ tools = coder_tools ,
172173 tool_choice = "auto" ,
173174 agent_name = self .__class__ .__name__ ,
174175 )
@@ -274,7 +275,7 @@ async def run(self, prompt: str, subtask_title: str) -> CoderToWriter:
274275
275276 completion_response = await self .model .chat (
276277 history = self .chat_history ,
277- tools = tools ,
278+ tools = coder_tools ,
278279 tool_choice = "auto" ,
279280 agent_name = self .__class__ .__name__ ,
280281 )
@@ -318,10 +319,12 @@ def __init__(
318319 comp_template : CompTemplate = CompTemplate ,
319320 format_output : FormatOutPut = FormatOutPut .Markdown ,
320321 user_output : UserOutput = None ,
322+ scholar : OpenAlexScholar = None ,
321323 ) -> None :
322324 super ().__init__ (task_id , model , max_chat_turns , user_output )
323325 self .format_out_put = format_output
324326 self .comp_template = comp_template
327+ self .scholar = scholar
325328 self .system_prompt = get_writer_prompt (format_output )
326329 self .available_images : list [str ] = []
327330
@@ -347,28 +350,86 @@ async def run(
347350 image_prompt = f"\n 可用的图片链接列表:\n { image_list } \n 请在写作时适当引用这些图片链接。"
348351 prompt = prompt + image_prompt
349352
350- try :
351- logger .info (f"{ self .__class__ .__name__ } :开始:执行对话" )
352- self .current_chat_turns = 0 # 重置对话轮次计数器
353+ logger .info (f"{ self .__class__ .__name__ } :开始:执行对话" )
354+ self .current_chat_turns += 1 # 重置对话轮次计数器
353355
354- # 更新对话历史
355- self .append_chat_history ({"role" : "system" , "content" : self .system_prompt })
356- self .append_chat_history ({"role" : "user" , "content" : prompt })
356+ # 更新对话历史
357+ self .append_chat_history ({"role" : "system" , "content" : self .system_prompt })
358+ self .append_chat_history ({"role" : "user" , "content" : prompt })
357359
358- # 获取历史消息用于本次对话
359- response = await self .model .chat (
360- history = self .chat_history ,
361- agent_name = self .__class__ .__name__ ,
362- sub_title = sub_title ,
363- )
360+ # 获取历史消息用于本次对话
361+ response = await self .model .chat (
362+ history = self .chat_history ,
363+ tools = writer_tools ,
364+ tool_choice = "auto" ,
365+ agent_name = self .__class__ .__name__ ,
366+ sub_title = sub_title ,
367+ )
368+
369+ if (
370+ hasattr (response .choices [0 ].message , "tool_calls" )
371+ and response .choices [0 ].message .tool_calls
372+ ):
373+ logger .info ("检测到工具调用" )
374+ tool_call = response .choices [0 ].message .tool_calls [0 ]
375+ tool_id = tool_call .id
376+ tool_call .function .name
377+ if tool_call .function .name == "search_papers" :
378+ logger .info ("调用工具: search_papers" )
379+ await redis_manager .publish_message (
380+ self .task_id ,
381+ SystemMessage (content = f"写作手调用{ tool_call .function .name } 工具" ),
382+ )
383+
384+ query = json .loads (tool_call .function .arguments )["query" ]
385+
386+ full_content = response .choices [0 ].message .content
387+ # 更新对话历史 - 添加助手的响应
388+ self .append_chat_history (
389+ {
390+ "role" : "assistant" ,
391+ "content" : full_content ,
392+ "tool_calls" : [
393+ {
394+ "id" : tool_id ,
395+ "type" : "function" ,
396+ "function" : {
397+ "name" : "search_papers" ,
398+ "arguments" : json .dumps ({"query" : query }),
399+ },
400+ }
401+ ],
402+ }
403+ )
404+
405+ try :
406+ papers = self .scholar .search_papers (query )
407+ except Exception as e :
408+ logger .error (f"搜索文献失败: { str (e )} " )
409+ return f"搜索文献失败: { str (e )} "
410+ # TODO: pass to frontend
411+ self .scholar .print_papers (papers )
412+ self .append_chat_history (
413+ {
414+ "role" : "tool" ,
415+ "content" : papers ,
416+ "tool_call_id" : tool_id ,
417+ "name" : "search_papers" ,
418+ }
419+ )
420+ next_response = await self .model .chat (
421+ history = self .chat_history ,
422+ tools = writer_tools ,
423+ tool_choice = "auto" ,
424+ agent_name = self .__class__ .__name__ ,
425+ sub_title = sub_title ,
426+ )
427+ response_content = next_response .choices [0 ].message .content
428+ else :
364429 response_content = response .choices [0 ].message .content
365- self .chat_history .append ({"role" : "assistant" , "content" : response_content })
366- logger .info (f"{ self .__class__ .__name__ } :完成:执行对话" )
367- return response_content
368- except Exception as e :
369- error_msg = f"执行过程中遇到错误: { str (e )} "
370- logger .error (f"Agent执行失败: { str (e )} " )
371- return error_msg
430+ self .chat_history .append ({"role" : "assistant" , "content" : response_content })
431+ logger .info (f"{ self .__class__ .__name__ } :完成:执行对话" )
432+ return response_content
372433
373434 async def summarize (self ) -> str :
374435 """
0 commit comments