33"""Tests for AG-UI orchestrators."""
44
55from collections .abc import AsyncGenerator
6- from types import SimpleNamespace
76from typing import Any
7+ from unittest .mock import MagicMock
88
9- from agent_framework import AgentResponseUpdate , FunctionInvocationConfiguration , TextContent , ai_function
9+ from agent_framework import (
10+ AgentResponseUpdate ,
11+ BaseChatClient ,
12+ ChatAgent ,
13+ FunctionInvocationConfiguration ,
14+ TextContent ,
15+ ai_function ,
16+ )
1017
1118from agent_framework_ag_ui ._agent import AgentConfig
1219from agent_framework_ag_ui ._orchestrators import DefaultOrchestrator , ExecutionContext
@@ -18,56 +25,53 @@ def server_tool() -> str:
1825 return "server"
1926
2027
21- class DummyAgent :
22- """Minimal agent stub to capture run_stream parameters."""
23-
24- def __init__ (self ) -> None :
25- self .default_options : dict [str , Any ] = {"tools" : [server_tool ], "response_format" : None }
26- self .tools = [server_tool ]
27- self .chat_client = SimpleNamespace (
28- function_invocation_configuration = FunctionInvocationConfiguration (),
29- )
30- self .seen_tools : list [Any ] | None = None
28+ def _create_mock_chat_agent (
29+ tools : list [Any ] | None = None ,
30+ response_format : Any = None ,
31+ capture_tools : list [Any ] | None = None ,
32+ capture_messages : list [Any ] | None = None ,
33+ ) -> ChatAgent :
34+ """Create a ChatAgent with mocked chat client for testing.
35+
36+ Args:
37+ tools: Tools to configure on the agent.
38+ response_format: Response format to configure.
39+ capture_tools: If provided, tools passed to run_stream will be appended here.
40+ capture_messages: If provided, messages passed to run_stream will be appended here.
41+ """
42+ mock_chat_client = MagicMock (spec = BaseChatClient )
43+ mock_chat_client .function_invocation_configuration = FunctionInvocationConfiguration ()
44+
45+ agent = ChatAgent (
46+ chat_client = mock_chat_client ,
47+ tools = tools or [server_tool ],
48+ response_format = response_format ,
49+ )
3150
32- async def run_stream (
33- self ,
51+ # Create a mock run_stream that captures parameters and yields a simple response
52+ async def mock_run_stream (
3453 messages : list [Any ],
3554 * ,
36- thread : Any ,
55+ thread : Any = None ,
3756 tools : list [Any ] | None = None ,
3857 ** kwargs : Any ,
3958 ) -> AsyncGenerator [AgentResponseUpdate , None ]:
40- self .seen_tools = tools
59+ if capture_tools is not None and tools is not None :
60+ capture_tools .extend (tools )
61+ if capture_messages is not None :
62+ capture_messages .extend (messages )
4163 yield AgentResponseUpdate (contents = [TextContent (text = "ok" )], role = "assistant" )
4264
65+ # Patch the run_stream method
66+ agent .run_stream = mock_run_stream # type: ignore[method-assign]
4367
44- class RecordingAgent :
45- """Agent stub that captures messages passed to run_stream."""
46-
47- def __init__ (self ) -> None :
48- self .chat_options = SimpleNamespace (tools = [], response_format = None )
49- self .tools : list [Any ] = []
50- self .chat_client = SimpleNamespace (
51- function_invocation_configuration = FunctionInvocationConfiguration (),
52- )
53- self .seen_messages : list [Any ] | None = None
54-
55- async def run_stream (
56- self ,
57- messages : list [Any ],
58- * ,
59- thread : Any ,
60- tools : list [Any ] | None = None ,
61- ** kwargs : Any ,
62- ) -> AsyncGenerator [AgentResponseUpdate , None ]:
63- self .seen_messages = messages
64- yield AgentResponseUpdate (contents = [TextContent (text = "ok" )], role = "assistant" )
68+ return agent
6569
6670
6771async def test_default_orchestrator_merges_client_tools () -> None :
6872 """Client tool declarations are merged with server tools before running agent."""
69-
70- agent = DummyAgent ( )
73+ captured_tools : list [ Any ] = []
74+ agent = _create_mock_chat_agent ( tools = [ server_tool ], capture_tools = captured_tools )
7175 orchestrator = DefaultOrchestrator ()
7276
7377 input_data = {
@@ -100,17 +104,16 @@ async def test_default_orchestrator_merges_client_tools() -> None:
100104 async for event in orchestrator .run (context ):
101105 events .append (event )
102106
103- assert agent . seen_tools is not None
104- tool_names = [getattr (tool , "name" , "?" ) for tool in agent . seen_tools ]
107+ assert len ( captured_tools ) > 0
108+ tool_names = [getattr (tool , "name" , "?" ) for tool in captured_tools ]
105109 assert "server_tool" in tool_names
106110 assert "get_weather" in tool_names
107111 assert agent .chat_client .function_invocation_configuration .additional_tools
108112
109113
110114async def test_default_orchestrator_with_camel_case_ids () -> None :
111115 """Client tool is able to extract camelCase IDs."""
112-
113- agent = DummyAgent ()
116+ agent = _create_mock_chat_agent ()
114117 orchestrator = DefaultOrchestrator ()
115118
116119 input_data = {
@@ -143,8 +146,7 @@ async def test_default_orchestrator_with_camel_case_ids() -> None:
143146
144147async def test_default_orchestrator_with_snake_case_ids () -> None :
145148 """Client tool is able to extract snake_case IDs."""
146-
147- agent = DummyAgent ()
149+ agent = _create_mock_chat_agent ()
148150 orchestrator = DefaultOrchestrator ()
149151
150152 input_data = {
@@ -177,8 +179,8 @@ async def test_default_orchestrator_with_snake_case_ids() -> None:
177179
178180async def test_state_context_injected_when_tool_call_state_mismatch () -> None :
179181 """State context should be injected when current state differs from tool call args."""
180-
181- agent = RecordingAgent ( )
182+ captured_messages : list [ Any ] = []
183+ agent = _create_mock_chat_agent ( tools = [], capture_messages = captured_messages )
182184 orchestrator = DefaultOrchestrator ()
183185
184186 tool_recipe = {"title" : "Salad" , "special_preferences" : []}
@@ -215,9 +217,9 @@ async def test_state_context_injected_when_tool_call_state_mismatch() -> None:
215217 async for _event in orchestrator .run (context ):
216218 pass
217219
218- assert agent . seen_messages is not None
220+ assert len ( captured_messages ) > 0
219221 state_messages = []
220- for msg in agent . seen_messages :
222+ for msg in captured_messages :
221223 role_value = msg .role .value if hasattr (msg .role , "value" ) else str (msg .role )
222224 if role_value != "system" :
223225 continue
@@ -230,8 +232,8 @@ async def test_state_context_injected_when_tool_call_state_mismatch() -> None:
230232
231233async def test_state_context_not_injected_when_tool_call_matches_state () -> None :
232234 """State context should be skipped when tool call args match current state."""
233-
234- agent = RecordingAgent ( )
235+ captured_messages : list [ Any ] = []
236+ agent = _create_mock_chat_agent ( tools = [], capture_messages = captured_messages )
235237 orchestrator = DefaultOrchestrator ()
236238
237239 input_data = {
@@ -264,9 +266,9 @@ async def test_state_context_not_injected_when_tool_call_matches_state() -> None
264266 async for _event in orchestrator .run (context ):
265267 pass
266268
267- assert agent . seen_messages is not None
269+ assert len ( captured_messages ) > 0
268270 state_messages = []
269- for msg in agent . seen_messages :
271+ for msg in captured_messages :
270272 role_value = msg .role .value if hasattr (msg .role , "value" ) else str (msg .role )
271273 if role_value != "system" :
272274 continue
0 commit comments