|
9 | 9 | from agentmesh.protocal.context import TeamContext, AgentOutput |
10 | 10 | from agentmesh.protocal.result import AgentAction, AgentActionType, ToolResult, AgentResult |
11 | 11 | from agentmesh.tools.base_tool import BaseTool |
| 12 | +from agentmesh.tools.base_tool import ToolStage |
12 | 13 |
|
13 | 14 |
|
14 | 15 | class Agent: |
@@ -48,10 +49,11 @@ def add_tool(self, tool: BaseTool): |
48 | 49 | self.tools.append(tool) |
49 | 50 |
|
50 | 51 | def _build_tools_prompt(self) -> str: |
51 | | - """Build the tool list description""" |
| 52 | + """Build the tool list description, only including pre-process tools""" |
52 | 53 | return "\n".join([ |
53 | 54 | f"{tool.name}: {tool.description} (parameters: {tool.params})" |
54 | 55 | for tool in self.tools |
| 56 | + if tool.stage == ToolStage.PRE_PROCESS # Only include pre-process tools |
55 | 57 | ]) |
56 | 58 |
|
57 | 59 | def _build_react_prompt(self) -> str: |
@@ -98,10 +100,19 @@ def _build_react_prompt(self) -> str: |
98 | 100 | return tools_prompt + ext_data_prompt + current_task_prompt |
99 | 101 |
|
100 | 102 | def _find_tool(self, tool_name: str): |
| 103 | + """Find and return a tool with the specified name""" |
101 | 104 | for tool in self.tools: |
102 | 105 | if tool.name == tool_name: |
103 | | - tool.model = self.model |
104 | | - return tool |
| 106 | + # Only pre-process stage tools can be actively called |
| 107 | + if tool.stage == ToolStage.PRE_PROCESS: |
| 108 | + tool.model = self.model |
| 109 | + tool.context = self # Set tool context |
| 110 | + return tool |
| 111 | + else: |
| 112 | + # If it's a post-process tool, return None to prevent direct calling |
| 113 | + logger.warning(f"Tool {tool_name} is a post-process tool and cannot be called directly.") |
| 114 | + return None |
| 115 | + return None |
105 | 116 |
|
106 | 117 | # output function based on mode |
107 | 118 | def output(self, message="", end="\n"): |
@@ -233,6 +244,16 @@ def step(self): |
233 | 244 | if "final_answer" in parsed and parsed["final_answer"] and parsed["final_answer"].lower() not in ["null", |
234 | 245 | "none"]: |
235 | 246 | final_answer = parsed["final_answer"] |
| 247 | + self.final_answer = final_answer |
| 248 | + |
| 249 | + # Store the final answer in team context |
| 250 | + self.team_context.agent_outputs.append( |
| 251 | + AgentOutput(agent_name=self.name, output=final_answer) |
| 252 | + ) |
| 253 | + |
| 254 | + # Execute all post-process tools |
| 255 | + self._execute_post_process_tools() |
| 256 | + |
236 | 257 | break |
237 | 258 |
|
238 | 259 | # Handle tool invocation |
@@ -273,19 +294,31 @@ def step(self): |
273 | 294 |
|
274 | 295 | current_step += 1 |
275 | 296 |
|
276 | | - # Save final result |
277 | | - result = final_answer if final_answer else raw_response |
278 | | - self.final_answer = result |
279 | | - self.team_context.agent_outputs.append( |
280 | | - AgentOutput(agent_name=self.name, output=result) |
281 | | - ) |
282 | | - |
283 | 297 | # Return a StepResult object |
284 | 298 | return AgentResult.success( |
285 | 299 | final_answer=self.final_answer, |
286 | 300 | step_count=current_step + 1 # +1 because we count steps starting from 1 |
287 | 301 | ) |
288 | 302 |
|
| 303 | + def _execute_post_process_tools(self): |
| 304 | + """Execute all post-process stage tools""" |
| 305 | + # Get all post-process stage tools |
| 306 | + post_process_tools = [tool for tool in self.tools if tool.stage == ToolStage.POST_PROCESS] |
| 307 | + |
| 308 | + # Execute each tool |
| 309 | + for tool in post_process_tools: |
| 310 | + # Set tool context |
| 311 | + tool.context = self |
| 312 | + |
| 313 | + # Execute tool (with empty parameters, tool will extract needed info from context) |
| 314 | + result = tool.execute({}) |
| 315 | + |
| 316 | + # Log result |
| 317 | + if result.status == "success": |
| 318 | + logger.info(f"Post-process tool {tool.name} executed successfully: {result.result.get('message', '')}") |
| 319 | + else: |
| 320 | + logger.warning(f"Post-process tool {tool.name} failed: {result.result}") |
| 321 | + |
289 | 322 | def should_invoke_next_agent(self) -> int: |
290 | 323 | """ |
291 | 324 | Determine if the next agent should be invoked based on the reply. |
|
0 commit comments