Skip to content

Commit dc753f2

Browse files
committed
Add support for using MCP tools in ReAct agents
1 parent 15be2cc commit dc753f2

File tree

4 files changed

+84
-24
lines changed

4 files changed

+84
-24
lines changed

coagent/agents/react_agent/agent.py

Lines changed: 77 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
BaseAgent,
1111
Context,
1212
handler,
13+
logger,
1314
)
1415
from coagent.core.util import get_func_args, pretty_trace_tool_call
16+
import mcputil
1517
from openai.types.responses import (
1618
EasyInputMessageParam,
1719
ResponseOutputText,
@@ -95,6 +97,33 @@ def model(self) -> Model:
9597
def model_settings(self) -> ModelSettings:
9698
return self._model_settings
9799

100+
async def started(self) -> None:
101+
# Extract MCP clients from tools list
102+
mcp_clients = [tool for tool in self._tools if isinstance(tool, mcputil.Client)]
103+
104+
# Filter out MCP clients from tools list
105+
self._tools = [
106+
tool for tool in self._tools if not isinstance(tool, mcputil.Client)
107+
]
108+
109+
# Load tools from all MCP clients concurrently
110+
async def get_client_tools(client):
111+
try:
112+
return await client.get_tools()
113+
except Exception as exc:
114+
# Log error but continue with empty tools list
115+
logger.error(f"Error getting tools from MCP client: {exc}")
116+
return []
117+
118+
# Fetch all tools concurrently and flatten the results
119+
all_mcp_tools = await asyncio.gather(
120+
*[get_client_tools(client) for client in mcp_clients]
121+
)
122+
123+
# Add all tools to the tools list
124+
for mcp_tools in all_mcp_tools:
125+
self._tools.extend(mcp_tools)
126+
98127
@handler
99128
async def handle_history(
100129
self, msg: InputHistory, ctx: Context
@@ -312,32 +341,58 @@ async def handle_function_call(
312341
args = {k: v for k, v in args.items() if k in want_arg_names}
313342
pretty_trace_tool_call(f"Actual Call: {name}", args)
314343

344+
tool_ctx = RunContext.with_tool(
345+
ctx,
346+
name=function_call.name,
347+
call_id=function_call.call_id,
348+
arguments=function_call.arguments,
349+
)
315350
# TODO: Check by argument types instead of names. E.g. `ctx` could be a `RunContext`.
316351
if __CTX_VARS_NAME__ in want_arg_names:
317-
args[__CTX_VARS_NAME__] = RunContext.with_tool(
318-
ctx,
319-
name=function_call.name,
320-
call_id=function_call.call_id,
321-
arguments=function_call.arguments,
322-
)
352+
args[__CTX_VARS_NAME__] = tool_ctx
323353

324-
raw_result = func(**args)
325-
if inspect.isawaitable(raw_result):
326-
result = await raw_result
354+
if isinstance(func, mcputil.Tool):
355+
result: mcputil.Result = await func.call(
356+
call_id=function_call.call_id, **args
357+
)
358+
async for event in result.events():
359+
if isinstance(event, mcputil.ProgressEvent):
360+
# Report progress to the context.
361+
tool_ctx.report_progress(
362+
progress=event.progress or 0,
363+
total=event.total or 0,
364+
message=event.message or "",
365+
)
366+
elif isinstance(event, mcputil.OutputEvent):
367+
return ToolCallOutputItem(
368+
raw_item=ResponseFunctionToolCallOutputItem(
369+
id="",
370+
call_id=function_call.call_id,
371+
output=str(event.output),
372+
type="function_call_output",
373+
status="completed",
374+
),
375+
output=event.output,
376+
type="tool_call_output_item",
377+
)
327378
else:
328-
result = raw_result
329-
330-
return ToolCallOutputItem(
331-
raw_item=ResponseFunctionToolCallOutputItem(
332-
id="",
333-
call_id=function_call.call_id,
334-
output=str(result),
335-
type="function_call_output",
336-
status="completed",
337-
),
338-
output=result,
339-
type="tool_call_output_item",
340-
)
379+
raw_result = func(**args)
380+
if inspect.isawaitable(raw_result):
381+
result = await raw_result
382+
else:
383+
result = raw_result
384+
385+
return ToolCallOutputItem(
386+
raw_item=ResponseFunctionToolCallOutputItem(
387+
id="",
388+
call_id=function_call.call_id,
389+
output=str(result),
390+
type="function_call_output",
391+
status="completed",
392+
),
393+
output=result,
394+
type="tool_call_output_item",
395+
)
341396

342397
async def get_chat_completion(
343398
self,

coagent/agents/react_agent/util.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import inspect
33
import typing
44

5+
import mcputil
56
from pydantic import Field, create_model
67
from pydantic.fields import FieldInfo
78

@@ -139,7 +140,10 @@ def greet(
139140
# Construct the pydantic mdoel for the _under_fn's function signature parameters.
140141
# 1. Get the function signature.
141142

142-
sig = inspect.signature(func)
143+
if isinstance(func, mcputil.Tool):
144+
sig = func.__sig__
145+
else:
146+
sig = inspect.signature(func)
143147

144148
# 2. Create a dictionary of field definitions for the Pydantic model
145149
fields = {}

examples/patterns/autonomous_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ async def main():
5151

5252
result = await reporter.run(
5353
InputHistory(
54-
messages=[InputMessage(role="user", content="What's the weather?")]
54+
messages=[InputMessage(role="user", content="What's the weather like?")]
5555
).encode(),
5656
stream=True,
5757
)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ litellm = "1.60.4"
3838
mcp = ">=1.19.0"
3939
jinja2 = "3.1.5"
4040
aiorwlock = ">=1.5.0"
41+
mcputil = "0.1.0"
4142

4243
# A list of optional dependencies, which are included in the
4344
# below `extras`.

0 commit comments

Comments
 (0)