|
10 | 10 | BaseAgent, |
11 | 11 | Context, |
12 | 12 | handler, |
| 13 | + logger, |
13 | 14 | ) |
14 | 15 | from coagent.core.util import get_func_args, pretty_trace_tool_call |
| 16 | +import mcputil |
15 | 17 | from openai.types.responses import ( |
16 | 18 | EasyInputMessageParam, |
17 | 19 | ResponseOutputText, |
@@ -95,6 +97,33 @@ def model(self) -> Model: |
95 | 97 | def model_settings(self) -> ModelSettings: |
96 | 98 | return self._model_settings |
97 | 99 |
|
| 100 | + async def started(self) -> None: |
| 101 | + # Extract MCP clients from tools list |
| 102 | + mcp_clients = [tool for tool in self._tools if isinstance(tool, mcputil.Client)] |
| 103 | + |
| 104 | + # Filter out MCP clients from tools list |
| 105 | + self._tools = [ |
| 106 | + tool for tool in self._tools if not isinstance(tool, mcputil.Client) |
| 107 | + ] |
| 108 | + |
| 109 | + # Load tools from all MCP clients concurrently |
| 110 | + async def get_client_tools(client): |
| 111 | + try: |
| 112 | + return await client.get_tools() |
| 113 | + except Exception as exc: |
| 114 | + # Log error but continue with empty tools list |
| 115 | + logger.error(f"Error getting tools from MCP client: {exc}") |
| 116 | + return [] |
| 117 | + |
| 118 | + # Fetch all tools concurrently and flatten the results |
| 119 | + all_mcp_tools = await asyncio.gather( |
| 120 | + *[get_client_tools(client) for client in mcp_clients] |
| 121 | + ) |
| 122 | + |
| 123 | + # Add all tools to the tools list |
| 124 | + for mcp_tools in all_mcp_tools: |
| 125 | + self._tools.extend(mcp_tools) |
| 126 | + |
98 | 127 | @handler |
99 | 128 | async def handle_history( |
100 | 129 | self, msg: InputHistory, ctx: Context |
@@ -312,32 +341,58 @@ async def handle_function_call( |
312 | 341 | args = {k: v for k, v in args.items() if k in want_arg_names} |
313 | 342 | pretty_trace_tool_call(f"Actual Call: {name}", args) |
314 | 343 |
|
| 344 | + tool_ctx = RunContext.with_tool( |
| 345 | + ctx, |
| 346 | + name=function_call.name, |
| 347 | + call_id=function_call.call_id, |
| 348 | + arguments=function_call.arguments, |
| 349 | + ) |
315 | 350 | # TODO: Check by argument types instead of names. E.g. `ctx` could be a `RunContext`. |
316 | 351 | if __CTX_VARS_NAME__ in want_arg_names: |
317 | | - args[__CTX_VARS_NAME__] = RunContext.with_tool( |
318 | | - ctx, |
319 | | - name=function_call.name, |
320 | | - call_id=function_call.call_id, |
321 | | - arguments=function_call.arguments, |
322 | | - ) |
| 352 | + args[__CTX_VARS_NAME__] = tool_ctx |
323 | 353 |
|
324 | | - raw_result = func(**args) |
325 | | - if inspect.isawaitable(raw_result): |
326 | | - result = await raw_result |
| 354 | + if isinstance(func, mcputil.Tool): |
| 355 | + result: mcputil.Result = await func.call( |
| 356 | + call_id=function_call.call_id, **args |
| 357 | + ) |
| 358 | + async for event in result.events(): |
| 359 | + if isinstance(event, mcputil.ProgressEvent): |
| 360 | + # Report progress to the context. |
| 361 | + tool_ctx.report_progress( |
| 362 | + progress=event.progress or 0, |
| 363 | + total=event.total or 0, |
| 364 | + message=event.message or "", |
| 365 | + ) |
| 366 | + elif isinstance(event, mcputil.OutputEvent): |
| 367 | + return ToolCallOutputItem( |
| 368 | + raw_item=ResponseFunctionToolCallOutputItem( |
| 369 | + id="", |
| 370 | + call_id=function_call.call_id, |
| 371 | + output=str(event.output), |
| 372 | + type="function_call_output", |
| 373 | + status="completed", |
| 374 | + ), |
| 375 | + output=event.output, |
| 376 | + type="tool_call_output_item", |
| 377 | + ) |
327 | 378 | else: |
328 | | - result = raw_result |
329 | | - |
330 | | - return ToolCallOutputItem( |
331 | | - raw_item=ResponseFunctionToolCallOutputItem( |
332 | | - id="", |
333 | | - call_id=function_call.call_id, |
334 | | - output=str(result), |
335 | | - type="function_call_output", |
336 | | - status="completed", |
337 | | - ), |
338 | | - output=result, |
339 | | - type="tool_call_output_item", |
340 | | - ) |
| 379 | + raw_result = func(**args) |
| 380 | + if inspect.isawaitable(raw_result): |
| 381 | + result = await raw_result |
| 382 | + else: |
| 383 | + result = raw_result |
| 384 | + |
| 385 | + return ToolCallOutputItem( |
| 386 | + raw_item=ResponseFunctionToolCallOutputItem( |
| 387 | + id="", |
| 388 | + call_id=function_call.call_id, |
| 389 | + output=str(result), |
| 390 | + type="function_call_output", |
| 391 | + status="completed", |
| 392 | + ), |
| 393 | + output=result, |
| 394 | + type="tool_call_output_item", |
| 395 | + ) |
341 | 396 |
|
342 | 397 | async def get_chat_completion( |
343 | 398 | self, |
|
0 commit comments