Skip to content

Commit 4a1b67a

Browse files
moonbox3alliscode
authored andcommitted
Python: fix(ag-ui): add MCP tool support for AG-UI approval flows (microsoft#3212)
* add MCP tool support for AG-UI approval flows * use attribute in place of property
1 parent 7ac6c10 commit 4a1b67a

File tree

7 files changed

+236
-113
lines changed

7 files changed

+236
-113
lines changed

python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py

Lines changed: 66 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,59 +3,85 @@
33
"""Tool handling helpers."""
44

55
import logging
6-
from typing import Any
6+
from typing import TYPE_CHECKING, Any
77

8-
from agent_framework import BaseChatClient, ChatAgent
8+
from agent_framework import BaseChatClient
9+
10+
if TYPE_CHECKING:
11+
from agent_framework import AgentProtocol
912

1013
logger = logging.getLogger(__name__)
1114

1215

13-
def collect_server_tools(agent: Any) -> list[Any]:
14-
"""Collect server tools from ChatAgent or duck-typed agent."""
15-
if isinstance(agent, ChatAgent):
16-
tools_from_agent = agent.default_options.get("tools")
17-
server_tools = list(tools_from_agent) if tools_from_agent else []
18-
logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools")
19-
for tool in server_tools:
20-
tool_name = getattr(tool, "name", "unknown")
21-
approval_mode = getattr(tool, "approval_mode", None)
22-
logger.info(f"[TOOLS] - {tool_name}: approval_mode={approval_mode}")
23-
return server_tools
24-
25-
try:
26-
default_options_attr = getattr(agent, "default_options", None)
27-
if default_options_attr is not None:
28-
if isinstance(default_options_attr, dict):
29-
return default_options_attr.get("tools") or []
30-
return getattr(default_options_attr, "tools", None) or []
31-
except AttributeError:
16+
def _collect_mcp_tool_functions(mcp_tools: list[Any]) -> list[Any]:
17+
"""Extract functions from connected MCP tools.
18+
19+
Args:
20+
mcp_tools: List of MCP tool instances.
21+
22+
Returns:
23+
List of functions from connected MCP tools.
24+
"""
25+
functions: list[Any] = []
26+
for mcp_tool in mcp_tools:
27+
if getattr(mcp_tool, "is_connected", False) and hasattr(mcp_tool, "functions"):
28+
functions.extend(mcp_tool.functions)
29+
return functions
30+
31+
32+
def collect_server_tools(agent: "AgentProtocol") -> list[Any]:
33+
"""Collect server tools from an agent.
34+
35+
This includes both regular tools from default_options and MCP tools.
36+
MCP tools are stored separately for lifecycle management but their
37+
functions need to be included for tool execution during approval flows.
38+
39+
Args:
40+
agent: Agent instance to collect tools from. Works with ChatAgent
41+
or any agent with default_options and optional mcp_tools attributes.
42+
43+
Returns:
44+
List of tools including both regular tools and connected MCP tool functions.
45+
"""
46+
# Get tools from default_options
47+
default_options = getattr(agent, "default_options", None)
48+
if default_options is None:
3249
return []
33-
return []
3450

51+
tools_from_agent = default_options.get("tools") if isinstance(default_options, dict) else None
52+
server_tools = list(tools_from_agent) if tools_from_agent else []
53+
54+
# Include functions from connected MCP tools (only available on ChatAgent)
55+
mcp_tools = getattr(agent, "mcp_tools", None)
56+
if mcp_tools:
57+
server_tools.extend(_collect_mcp_tool_functions(mcp_tools))
58+
59+
logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools")
60+
for tool in server_tools:
61+
tool_name = getattr(tool, "name", "unknown")
62+
approval_mode = getattr(tool, "approval_mode", None)
63+
logger.info(f"[TOOLS] - {tool_name}: approval_mode={approval_mode}")
64+
return server_tools
3565

36-
def register_additional_client_tools(agent: Any, client_tools: list[Any] | None) -> None:
37-
"""Register client tools as additional declaration-only tools to avoid server execution."""
66+
67+
def register_additional_client_tools(agent: "AgentProtocol", client_tools: list[Any] | None) -> None:
68+
"""Register client tools as additional declaration-only tools to avoid server execution.
69+
70+
Args:
71+
agent: Agent instance to register tools on. Works with ChatAgent
72+
or any agent with a chat_client attribute.
73+
client_tools: List of client tools to register.
74+
"""
3875
if not client_tools:
3976
return
4077

41-
if isinstance(agent, ChatAgent):
42-
chat_client = agent.chat_client
43-
if isinstance(chat_client, BaseChatClient) and chat_client.function_invocation_configuration is not None:
44-
chat_client.function_invocation_configuration.additional_tools = client_tools
45-
logger.debug(f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)")
78+
chat_client = getattr(agent, "chat_client", None)
79+
if chat_client is None:
4680
return
4781

48-
try:
49-
chat_client_attr = getattr(agent, "chat_client", None)
50-
if chat_client_attr is not None:
51-
fic = getattr(chat_client_attr, "function_invocation_configuration", None)
52-
if fic is not None:
53-
fic.additional_tools = client_tools # type: ignore[attr-defined]
54-
logger.debug(
55-
f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)"
56-
)
57-
except AttributeError:
58-
return
82+
if isinstance(chat_client, BaseChatClient) and chat_client.function_invocation_configuration is not None:
83+
chat_client.function_invocation_configuration.additional_tools = client_tools
84+
logger.debug(f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)")
5985

6086

6187
def merge_tools(server_tools: list[Any], client_tools: list[Any] | None) -> list[Any] | None:

python/packages/ag-ui/tests/test_orchestrators.py

Lines changed: 55 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,17 @@
33
"""Tests for AG-UI orchestrators."""
44

55
from collections.abc import AsyncGenerator
6-
from types import SimpleNamespace
76
from typing import Any
7+
from unittest.mock import MagicMock
88

9-
from agent_framework import AgentResponseUpdate, FunctionInvocationConfiguration, TextContent, ai_function
9+
from agent_framework import (
10+
AgentResponseUpdate,
11+
BaseChatClient,
12+
ChatAgent,
13+
FunctionInvocationConfiguration,
14+
TextContent,
15+
ai_function,
16+
)
1017

1118
from agent_framework_ag_ui._agent import AgentConfig
1219
from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, ExecutionContext
@@ -18,56 +25,53 @@ def server_tool() -> str:
1825
return "server"
1926

2027

21-
class DummyAgent:
22-
"""Minimal agent stub to capture run_stream parameters."""
23-
24-
def __init__(self) -> None:
25-
self.default_options: dict[str, Any] = {"tools": [server_tool], "response_format": None}
26-
self.tools = [server_tool]
27-
self.chat_client = SimpleNamespace(
28-
function_invocation_configuration=FunctionInvocationConfiguration(),
29-
)
30-
self.seen_tools: list[Any] | None = None
28+
def _create_mock_chat_agent(
29+
tools: list[Any] | None = None,
30+
response_format: Any = None,
31+
capture_tools: list[Any] | None = None,
32+
capture_messages: list[Any] | None = None,
33+
) -> ChatAgent:
34+
"""Create a ChatAgent with mocked chat client for testing.
35+
36+
Args:
37+
tools: Tools to configure on the agent.
38+
response_format: Response format to configure.
39+
capture_tools: If provided, tools passed to run_stream will be appended here.
40+
capture_messages: If provided, messages passed to run_stream will be appended here.
41+
"""
42+
mock_chat_client = MagicMock(spec=BaseChatClient)
43+
mock_chat_client.function_invocation_configuration = FunctionInvocationConfiguration()
44+
45+
agent = ChatAgent(
46+
chat_client=mock_chat_client,
47+
tools=tools or [server_tool],
48+
response_format=response_format,
49+
)
3150

32-
async def run_stream(
33-
self,
51+
# Create a mock run_stream that captures parameters and yields a simple response
52+
async def mock_run_stream(
3453
messages: list[Any],
3554
*,
36-
thread: Any,
55+
thread: Any = None,
3756
tools: list[Any] | None = None,
3857
**kwargs: Any,
3958
) -> AsyncGenerator[AgentResponseUpdate, None]:
40-
self.seen_tools = tools
59+
if capture_tools is not None and tools is not None:
60+
capture_tools.extend(tools)
61+
if capture_messages is not None:
62+
capture_messages.extend(messages)
4163
yield AgentResponseUpdate(contents=[TextContent(text="ok")], role="assistant")
4264

65+
# Patch the run_stream method
66+
agent.run_stream = mock_run_stream # type: ignore[method-assign]
4367

44-
class RecordingAgent:
45-
"""Agent stub that captures messages passed to run_stream."""
46-
47-
def __init__(self) -> None:
48-
self.chat_options = SimpleNamespace(tools=[], response_format=None)
49-
self.tools: list[Any] = []
50-
self.chat_client = SimpleNamespace(
51-
function_invocation_configuration=FunctionInvocationConfiguration(),
52-
)
53-
self.seen_messages: list[Any] | None = None
54-
55-
async def run_stream(
56-
self,
57-
messages: list[Any],
58-
*,
59-
thread: Any,
60-
tools: list[Any] | None = None,
61-
**kwargs: Any,
62-
) -> AsyncGenerator[AgentResponseUpdate, None]:
63-
self.seen_messages = messages
64-
yield AgentResponseUpdate(contents=[TextContent(text="ok")], role="assistant")
68+
return agent
6569

6670

6771
async def test_default_orchestrator_merges_client_tools() -> None:
6872
"""Client tool declarations are merged with server tools before running agent."""
69-
70-
agent = DummyAgent()
73+
captured_tools: list[Any] = []
74+
agent = _create_mock_chat_agent(tools=[server_tool], capture_tools=captured_tools)
7175
orchestrator = DefaultOrchestrator()
7276

7377
input_data = {
@@ -100,17 +104,16 @@ async def test_default_orchestrator_merges_client_tools() -> None:
100104
async for event in orchestrator.run(context):
101105
events.append(event)
102106

103-
assert agent.seen_tools is not None
104-
tool_names = [getattr(tool, "name", "?") for tool in agent.seen_tools]
107+
assert len(captured_tools) > 0
108+
tool_names = [getattr(tool, "name", "?") for tool in captured_tools]
105109
assert "server_tool" in tool_names
106110
assert "get_weather" in tool_names
107111
assert agent.chat_client.function_invocation_configuration.additional_tools
108112

109113

110114
async def test_default_orchestrator_with_camel_case_ids() -> None:
111115
"""Client tool is able to extract camelCase IDs."""
112-
113-
agent = DummyAgent()
116+
agent = _create_mock_chat_agent()
114117
orchestrator = DefaultOrchestrator()
115118

116119
input_data = {
@@ -143,8 +146,7 @@ async def test_default_orchestrator_with_camel_case_ids() -> None:
143146

144147
async def test_default_orchestrator_with_snake_case_ids() -> None:
145148
"""Client tool is able to extract snake_case IDs."""
146-
147-
agent = DummyAgent()
149+
agent = _create_mock_chat_agent()
148150
orchestrator = DefaultOrchestrator()
149151

150152
input_data = {
@@ -177,8 +179,8 @@ async def test_default_orchestrator_with_snake_case_ids() -> None:
177179

178180
async def test_state_context_injected_when_tool_call_state_mismatch() -> None:
179181
"""State context should be injected when current state differs from tool call args."""
180-
181-
agent = RecordingAgent()
182+
captured_messages: list[Any] = []
183+
agent = _create_mock_chat_agent(tools=[], capture_messages=captured_messages)
182184
orchestrator = DefaultOrchestrator()
183185

184186
tool_recipe = {"title": "Salad", "special_preferences": []}
@@ -215,9 +217,9 @@ async def test_state_context_injected_when_tool_call_state_mismatch() -> None:
215217
async for _event in orchestrator.run(context):
216218
pass
217219

218-
assert agent.seen_messages is not None
220+
assert len(captured_messages) > 0
219221
state_messages = []
220-
for msg in agent.seen_messages:
222+
for msg in captured_messages:
221223
role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
222224
if role_value != "system":
223225
continue
@@ -230,8 +232,8 @@ async def test_state_context_injected_when_tool_call_state_mismatch() -> None:
230232

231233
async def test_state_context_not_injected_when_tool_call_matches_state() -> None:
232234
"""State context should be skipped when tool call args match current state."""
233-
234-
agent = RecordingAgent()
235+
captured_messages: list[Any] = []
236+
agent = _create_mock_chat_agent(tools=[], capture_messages=captured_messages)
235237
orchestrator = DefaultOrchestrator()
236238

237239
input_data = {
@@ -264,9 +266,9 @@ async def test_state_context_not_injected_when_tool_call_matches_state() -> None
264266
async for _event in orchestrator.run(context):
265267
pass
266268

267-
assert agent.seen_messages is not None
269+
assert len(captured_messages) > 0
268270
state_messages = []
269-
for msg in agent.seen_messages:
271+
for msg in captured_messages:
270272
role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
271273
if role_value != "system":
272274
continue

0 commit comments

Comments
 (0)