1 | | -from typing import Any |
| 1 | +from collections import defaultdict |
| 2 | +from typing import Any, Optional |
2 | 3 |
3 | 4 | import pytest |
4 | 5 |
5 | | -# Types from the openai library used by the SDK |
6 | | -from openai.types.responses import ResponseFunctionToolCall, ResponseOutputMessage |
7 | | - |
8 | | -# Core SDK Imports |
9 | 6 | from agents.agent import Agent |
10 | | -from agents.items import ModelResponse |
| 7 | +from agents.items import ModelResponse, TResponseInputItem |
11 | 8 | from agents.lifecycle import AgentHooks |
12 | | -from agents.models.interface import Model |
13 | 9 | from agents.run import Runner |
14 | | -from agents.tool import Tool, function_tool |
15 | | -from agents.usage import InputTokensDetails, OutputTokensDetails, Usage |
16 | | - |
17 | | - |
18 | | -# --- 1. Spy Hook Implementation --- |
19 | | -class LoggingAgentHooks(AgentHooks[Any]): |
20 | | - def __init__(self): |
21 | | - super().__init__() |
22 | | - self.called_hooks: list[str] = [] |
23 | | - |
24 | | - # Spy on the NEW hooks |
25 | | - async def on_llm_start(self, *args, **kwargs): |
26 | | - self.called_hooks.append("on_llm_start") |
27 | | - |
28 | | - async def on_llm_end(self, *args, **kwargs): |
29 | | - self.called_hooks.append("on_llm_end") |
| 10 | +from agents.run_context import RunContextWrapper, TContext |
| 11 | +from agents.tool import Tool |
30 | 12 |
31 | | - # Spy on EXISTING hooks to serve as landmarks for sequence verification |
32 | | - async def on_start(self, *args, **kwargs): |
33 | | - self.called_hooks.append("on_start") |
| 13 | +from .fake_model import FakeModel |
| 14 | +from .test_responses import ( |
| 15 | + get_function_tool, |
| 16 | + get_text_message, |
| 17 | +) |
34 | 18 |
35 | | - async def on_end(self, *args, **kwargs): |
36 | | - self.called_hooks.append("on_end") |
37 | | - |
38 | | - async def on_tool_start(self, *args, **kwargs): |
39 | | - self.called_hooks.append("on_tool_start") |
40 | | - |
41 | | - async def on_tool_end(self, *args, **kwargs): |
42 | | - self.called_hooks.append("on_tool_end") |
43 | | - |
44 | | - |
45 | | -# --- 2. Mock Model and Tools --- |
46 | | -class MockModel(Model): |
47 | | - """A mock model that can be configured to either return a chat message or a tool call.""" |
48 | 19 |
| 20 | +class AgentHooksForTests(AgentHooks): |
49 | 21 | def __init__(self): |
50 | | - self._call_count = 0 |
51 | | - self._should_call_tool = False |
52 | | - self._tool_to_call: Tool | None = None |
53 | | - |
54 | | - def configure_for_tool_call(self, tool: Tool): |
55 | | - self._should_call_tool = True |
56 | | - self._tool_to_call = tool |
57 | | - |
58 | | - def configure_for_chat(self): |
59 | | - self._should_call_tool = False |
60 | | - self._tool_to_call = None |
61 | | - |
62 | | - async def get_response(self, *args, **kwargs) -> ModelResponse: |
63 | | - self._call_count += 1 |
64 | | - response_items: list[Any] = [] |
65 | | - |
66 | | - if self._should_call_tool and self._call_count == 1: |
67 | | - response_items.append( |
68 | | - ResponseFunctionToolCall( |
69 | | - name=self._tool_to_call.name, |
70 | | - arguments="{}", |
71 | | - call_id="call123", |
72 | | - type="function_call", |
73 | | - ) |
74 | | - ) |
75 | | - else: |
76 | | - response_items.append( |
77 | | - ResponseOutputMessage( |
78 | | - id="msg1", |
79 | | - content=[{"type": "output_text", "text": "Mock response", "annotations": []}], |
80 | | - role="assistant", |
81 | | - status="completed", |
82 | | - type="message", |
83 | | - ) |
84 | | - ) |
85 | | - |
86 | | - mock_usage = Usage( |
87 | | - requests=1, |
88 | | - input_tokens=10, |
89 | | - output_tokens=10, |
90 | | - total_tokens=20, |
91 | | - input_tokens_details=InputTokensDetails(cached_tokens=0), |
92 | | - output_tokens_details=OutputTokensDetails(reasoning_tokens=0), |
93 | | - ) |
94 | | - return ModelResponse(output=response_items, usage=mock_usage, response_id="resp123") |
95 | | - |
96 | | - async def stream_response(self, *args, **kwargs): |
97 | | - final_response = await self.get_response(*args, **kwargs) |
98 | | - from openai.types.responses import ResponseCompletedEvent |
99 | | - |
100 | | - class MockSDKResponse: |
101 | | - def __init__(self, id, output, usage): |
102 | | - self.id, self.output, self.usage = id, output, usage |
103 | | - |
104 | | - yield ResponseCompletedEvent( |
105 | | - response=MockSDKResponse( |
106 | | - final_response.response_id, final_response.output, final_response.usage |
107 | | - ), |
108 | | - type="response_completed", |
109 | | - ) |
110 | | - |
111 | | - |
112 | | -@function_tool |
113 | | -def mock_tool(a: int, b: int) -> int: |
114 | | - """A mock tool for testing tool call hooks.""" |
115 | | - return a + b |
116 | | - |
117 | | - |
118 | | -# --- 3. Pytest Fixtures for Test Setup --- |
119 | | -@pytest.fixture |
120 | | -def logging_hooks() -> LoggingAgentHooks: |
121 | | - """Provides a fresh instance of LoggingAgentHooks for each test.""" |
122 | | - return LoggingAgentHooks() |
123 | | - |
124 | | - |
125 | | -@pytest.fixture |
126 | | -def chat_agent(logging_hooks: LoggingAgentHooks) -> Agent: |
127 | | - """Provides an agent configured for a simple chat interaction.""" |
128 | | - mock_model = MockModel() |
129 | | - mock_model.configure_for_chat() |
130 | | - return Agent( |
131 | | - name="ChatAgent", instructions="Test agent for chat.", model=mock_model, hooks=logging_hooks |
132 | | - ) |
133 | | - |
134 | | - |
135 | | -@pytest.fixture |
136 | | -def tool_agent(logging_hooks: LoggingAgentHooks) -> Agent: |
137 | | - """Provides an agent configured to use a tool.""" |
138 | | - mock_model = MockModel() |
139 | | - mock_model.configure_for_tool_call(mock_tool) |
140 | | - return Agent( |
141 | | - name="ToolAgent", |
142 | | - instructions="Test agent for tools.", |
143 | | - model=mock_model, |
144 | | - hooks=logging_hooks, |
145 | | - tools=[mock_tool], |
146 | | - ) |
147 | | - |
148 | | - |
149 | | -# --- 4. Test Cases Focused on New Hooks --- |
150 | | -@pytest.mark.asyncio |
151 | | -async def test_llm_hooks_fire_in_chat_scenario(chat_agent: Agent, logging_hooks: LoggingAgentHooks): |
152 | | - """ |
153 | | - Tests that on_llm_start and on_llm_end fire correctly for a chat-only turn. |
154 | | - """ |
155 | | - await Runner.run(chat_agent, "Hello") |
156 | | - |
157 | | - sequence = logging_hooks.called_hooks |
158 | | - |
159 | | - expected_sequence = [ |
160 | | - "on_start", |
161 | | - "on_llm_start", |
162 | | - "on_llm_end", |
163 | | - "on_end", |
164 | | - ] |
165 | | - assert sequence == expected_sequence |
166 | | - |
167 | | - |
168 | | -@pytest.mark.asyncio |
169 | | -async def test_llm_hooks_wrap_tool_hooks_in_tool_scenario( |
170 | | - tool_agent: Agent, logging_hooks: LoggingAgentHooks |
171 | | -): |
172 | | - """ |
173 | | - Tests that on_llm_start and on_llm_end wrap the tool execution cycle. |
174 | | - """ |
175 | | - await Runner.run(tool_agent, "Use your tool") |
176 | | - |
177 | | - sequence = logging_hooks.called_hooks |
178 | | - |
179 | | - expected_sequence = [ |
180 | | - "on_start", |
181 | | - "on_llm_start", |
182 | | - "on_llm_end", |
183 | | - "on_tool_start", |
184 | | - "on_tool_end", |
185 | | - "on_llm_start", |
186 | | - "on_llm_end", |
187 | | - "on_end", |
188 | | - ] |
189 | | - assert sequence == expected_sequence |
190 | | - |
191 | | - |
| 22 | + self.events: dict[str, int] = defaultdict(int) |
| 23 | + |
| 24 | + def reset(self): |
| 25 | + self.events.clear() |
| 26 | + |
| 27 | + async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None: |
| 28 | + self.events["on_start"] += 1 |
| 29 | + |
| 30 | + async def on_end( |
| 31 | + self, context: RunContextWrapper[TContext], agent: Agent[TContext], output: Any |
| 32 | + ) -> None: |
| 33 | + self.events["on_end"] += 1 |
| 34 | + |
| 35 | + async def on_handoff( |
| 36 | + self, context: RunContextWrapper[TContext], agent: Agent[TContext], source: Agent[TContext] |
| 37 | + ) -> None: |
| 38 | + self.events["on_handoff"] += 1 |
| 39 | + |
| 40 | + async def on_tool_start( |
| 41 | + self, context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool |
| 42 | + ) -> None: |
| 43 | + self.events["on_tool_start"] += 1 |
| 44 | + |
| 45 | + async def on_tool_end( |
| 46 | + self, |
| 47 | + context: RunContextWrapper[TContext], |
| 48 | + agent: Agent[TContext], |
| 49 | + tool: Tool, |
| 50 | + result: str, |
| 51 | + ) -> None: |
| 52 | + self.events["on_tool_end"] += 1 |
| 53 | + |
| 54 | + # NEW: LLM hooks |
| 55 | + async def on_llm_start( |
| 56 | + self, |
| 57 | + context: RunContextWrapper[TContext], |
| 58 | + agent: Agent[TContext], |
| 59 | + system_prompt: Optional[str], |
| 60 | + input_items: list[TResponseInputItem], |
| 61 | + ) -> None: |
| 62 | + self.events["on_llm_start"] += 1 |
| 63 | + |
| 64 | + async def on_llm_end( |
| 65 | + self, |
| 66 | +        context: RunContextWrapper[TContext], |
| 67 | + agent: Agent[TContext], |
| 68 | + response: ModelResponse, |
| 69 | + ) -> None: |
| 70 | + self.events["on_llm_end"] += 1 |
| 71 | + |
| 72 | + |
| 73 | +# Example test using the above hooks: |
192 | 74 | @pytest.mark.asyncio |
193 | | -async def test_no_hooks_run_if_hooks_is_none(): |
194 | | - """ |
195 | | - Ensures that the agent runs without error when agent.hooks is None. |
196 | | - """ |
197 | | - mock_model = MockModel() |
198 | | - mock_model.configure_for_chat() |
199 | | - agent_no_hooks = Agent( |
200 | | - name="NoHooksAgent", instructions="Test agent without hooks.", model=mock_model, hooks=None |
| 75 | +async def test_non_streamed_agent_hooks_with_llm(): |
| 76 | + hooks = AgentHooksForTests() |
| 77 | + model = FakeModel() |
| 78 | + agent = Agent( |
| 79 | + name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks |
201 | 80 | ) |
202 | | - |
203 | | - try: |
204 | | - await Runner.run(agent_no_hooks, "Hello") |
205 | | - except Exception as e: |
206 | | - pytest.fail(f"Runner.run failed when agent.hooks was None: {e}") |
| 81 | + # Simulate a single LLM call producing an output: |
| 82 | + model.set_next_output([get_text_message("hello")]) |
| 83 | + await Runner.run(agent, input="hello") |
| 84 | + # Expect one on_start, one on_llm_start, one on_llm_end, and one on_end |
| 85 | + assert hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1} |
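
For the tool-call path that the removed MockModel-based test exercised, a similar hooks assertion could sit alongside the test above. The sketch below is an illustration layered on this diff, not part of the committed change: it assumes the `FakeModel.add_multiple_turn_outputs` and `get_function_tool_call` helpers used elsewhere in the test suite, and `get_function_tool_call` would need to be added to the `.test_responses` import block above.

# Hypothetical follow-up test (a sketch, assuming the helpers noted above exist):
@pytest.mark.asyncio
async def test_agent_hooks_with_llm_and_tool_call():
    hooks = AgentHooksForTests()
    model = FakeModel()
    agent = Agent(
        name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks
    )
    # Two model turns: the first requests the "f" tool, the second produces the final output.
    model.add_multiple_turn_outputs(
        [
            [get_text_message("calling tool"), get_function_tool_call("f", "{}")],
            [get_text_message("done")],
        ]
    )
    await Runner.run(agent, input="hello")
    # Each model turn fires on_llm_start/on_llm_end, wrapping the single tool execution.
    assert hooks.events == {
        "on_start": 1,
        "on_llm_start": 2,
        "on_llm_end": 2,
        "on_tool_start": 1,
        "on_tool_end": 1,
        "on_end": 1,
    }

This mirrors the hook ordering the deleted sequence-based test asserted (two LLM start/end pairs bracketing one tool start/end pair), but expressed as the count-based assertions used by the new hooks class.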