-from typing import Any
+from collections import defaultdict
+from typing import Any, Optional

import pytest

-# Types from the openai library used by the SDK
-from openai.types.responses import ResponseFunctionToolCall, ResponseOutputMessage
-
-# Core SDK Imports
from agents.agent import Agent
-from agents.items import ModelResponse
+from agents.items import ModelResponse, TResponseInputItem
from agents.lifecycle import AgentHooks
-from agents.models.interface import Model
from agents.run import Runner
-from agents.tool import Tool, function_tool
-from agents.usage import InputTokensDetails, OutputTokensDetails, Usage
-
-
-# --- 1. Spy Hook Implementation ---
-class LoggingAgentHooks(AgentHooks[Any]):
-    def __init__(self):
-        super().__init__()
-        self.called_hooks: list[str] = []
-
-    # Spy on the NEW hooks
-    async def on_llm_start(self, *args, **kwargs):
-        self.called_hooks.append("on_llm_start")
-
-    async def on_llm_end(self, *args, **kwargs):
-        self.called_hooks.append("on_llm_end")
+from agents.run_context import RunContextWrapper, TContext
+from agents.tool import Tool

-    # Spy on EXISTING hooks to serve as landmarks for sequence verification
-    async def on_start(self, *args, **kwargs):
-        self.called_hooks.append("on_start")
+from .fake_model import FakeModel
+from .test_responses import (
+    get_function_tool,
+    get_text_message,
+)

-    async def on_end(self, *args, **kwargs):
-        self.called_hooks.append("on_end")
-
-    async def on_tool_start(self, *args, **kwargs):
-        self.called_hooks.append("on_tool_start")
-
-    async def on_tool_end(self, *args, **kwargs):
-        self.called_hooks.append("on_tool_end")
-
-
-# --- 2. Mock Model and Tools ---
-class MockModel(Model):
-    """A mock model that can be configured to either return a chat message or a tool call."""

+class AgentHooksForTests(AgentHooks):
    def __init__(self):
-        self._call_count = 0
-        self._should_call_tool = False
-        self._tool_to_call: Tool | None = None
-
-    def configure_for_tool_call(self, tool: Tool):
-        self._should_call_tool = True
-        self._tool_to_call = tool
-
-    def configure_for_chat(self):
-        self._should_call_tool = False
-        self._tool_to_call = None
-
-    async def get_response(self, *args, **kwargs) -> ModelResponse:
-        self._call_count += 1
-        response_items: list[Any] = []
-
-        if self._should_call_tool and self._call_count == 1:
-            response_items.append(
-                ResponseFunctionToolCall(
-                    name=self._tool_to_call.name,
-                    arguments="{}",
-                    call_id="call123",
-                    type="function_call",
-                )
-            )
-        else:
-            response_items.append(
-                ResponseOutputMessage(
-                    id="msg1",
-                    content=[{"type": "output_text", "text": "Mock response", "annotations": []}],
-                    role="assistant",
-                    status="completed",
-                    type="message",
-                )
-            )
-
-        mock_usage = Usage(
-            requests=1,
-            input_tokens=10,
-            output_tokens=10,
-            total_tokens=20,
-            input_tokens_details=InputTokensDetails(cached_tokens=0),
-            output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
-        )
-        return ModelResponse(output=response_items, usage=mock_usage, response_id="resp123")
-
-    async def stream_response(self, *args, **kwargs):
-        final_response = await self.get_response(*args, **kwargs)
-        from openai.types.responses import ResponseCompletedEvent
-
-        class MockSDKResponse:
-            def __init__(self, id, output, usage):
-                self.id, self.output, self.usage = id, output, usage
-
-        yield ResponseCompletedEvent(
-            response=MockSDKResponse(
-                final_response.response_id, final_response.output, final_response.usage
-            ),
-            type="response_completed",
-        )
-
-
-@function_tool
-def mock_tool(a: int, b: int) -> int:
-    """A mock tool for testing tool call hooks."""
-    return a + b
-
-
-# --- 3. Pytest Fixtures for Test Setup ---
-@pytest.fixture
-def logging_hooks() -> LoggingAgentHooks:
-    """Provides a fresh instance of LoggingAgentHooks for each test."""
-    return LoggingAgentHooks()
-
-
-@pytest.fixture
-def chat_agent(logging_hooks: LoggingAgentHooks) -> Agent:
-    """Provides an agent configured for a simple chat interaction."""
-    mock_model = MockModel()
-    mock_model.configure_for_chat()
-    return Agent(
-        name="ChatAgent", instructions="Test agent for chat.", model=mock_model, hooks=logging_hooks
-    )
-
-
-@pytest.fixture
-def tool_agent(logging_hooks: LoggingAgentHooks) -> Agent:
-    """Provides an agent configured to use a tool."""
-    mock_model = MockModel()
-    mock_model.configure_for_tool_call(mock_tool)
-    return Agent(
-        name="ToolAgent",
-        instructions="Test agent for tools.",
-        model=mock_model,
-        hooks=logging_hooks,
-        tools=[mock_tool],
-    )
-
-
-# --- 4. Test Cases Focused on New Hooks ---
-@pytest.mark.asyncio
-async def test_llm_hooks_fire_in_chat_scenario(chat_agent: Agent, logging_hooks: LoggingAgentHooks):
-    """
-    Tests that on_llm_start and on_llm_end fire correctly for a chat-only turn.
-    """
-    await Runner.run(chat_agent, "Hello")
-
-    sequence = logging_hooks.called_hooks
-
-    expected_sequence = [
-        "on_start",
-        "on_llm_start",
-        "on_llm_end",
-        "on_end",
-    ]
-    assert sequence == expected_sequence
-
-
-@pytest.mark.asyncio
-async def test_llm_hooks_wrap_tool_hooks_in_tool_scenario(
-    tool_agent: Agent, logging_hooks: LoggingAgentHooks
-):
-    """
-    Tests that on_llm_start and on_llm_end wrap the tool execution cycle.
-    """
-    await Runner.run(tool_agent, "Use your tool")
-
-    sequence = logging_hooks.called_hooks
-
-    expected_sequence = [
-        "on_start",
-        "on_llm_start",
-        "on_llm_end",
-        "on_tool_start",
-        "on_tool_end",
-        "on_llm_start",
-        "on_llm_end",
-        "on_end",
-    ]
-    assert sequence == expected_sequence
-
-
+        self.events: dict[str, int] = defaultdict(int)
+
+    def reset(self):
+        self.events.clear()
+
+    async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None:
+        self.events["on_start"] += 1
+
+    async def on_end(
+        self, context: RunContextWrapper[TContext], agent: Agent[TContext], output: Any
+    ) -> None:
+        self.events["on_end"] += 1
+
+    async def on_handoff(
+        self, context: RunContextWrapper[TContext], agent: Agent[TContext], source: Agent[TContext]
+    ) -> None:
+        self.events["on_handoff"] += 1
+
+    async def on_tool_start(
+        self, context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool
+    ) -> None:
+        self.events["on_tool_start"] += 1
+
+    async def on_tool_end(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        tool: Tool,
+        result: str,
+    ) -> None:
+        self.events["on_tool_end"] += 1
+
+    # NEW: LLM hooks
+    async def on_llm_start(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        system_prompt: Optional[str],
+        input_items: list[TResponseInputItem],
+    ) -> None:
+        self.events["on_llm_start"] += 1
+
+    async def on_llm_end(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        response: ModelResponse,
+    ) -> None:
+        self.events["on_llm_end"] += 1
+
+
+# Example test using the above hooks:
@pytest.mark.asyncio
-async def test_no_hooks_run_if_hooks_is_none():
-    """
-    Ensures that the agent runs without error when agent.hooks is None.
-    """
-    mock_model = MockModel()
-    mock_model.configure_for_chat()
-    agent_no_hooks = Agent(
-        name="NoHooksAgent", instructions="Test agent without hooks.", model=mock_model, hooks=None
+async def test_non_streamed_agent_hooks_with_llm():
+    hooks = AgentHooksForTests()
+    model = FakeModel()
+    agent = Agent(
+        name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks
    )
-
-    try:
-        await Runner.run(agent_no_hooks, "Hello")
-    except Exception as e:
-        pytest.fail(f"Runner.run failed when agent.hooks was None: {e}")
+    # Simulate a single LLM call producing an output:
+    model.set_next_output([get_text_message("hello")])
+    await Runner.run(agent, input="hello")
+    # Expect one on_start, one on_llm_start, one on_llm_end, and one on_end
+    assert hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}
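
A tool-call counterpart to the test above would show the new hooks wrapping a tool turn (two model calls around one tool execution), mirroring the ordering asserted by the removed test_llm_hooks_wrap_tool_hooks_in_tool_scenario. The sketch below is not part of this diff; it assumes two helpers from the repo's test suite that the diff does not show, get_function_tool_call in test_responses and FakeModel.add_multiple_turn_outputs.

@pytest.mark.asyncio
async def test_agent_hooks_with_tool_call_sketch():
    # Sketch only: assumes this additional helper import, not shown in the diff above:
    # from .test_responses import get_function_tool_call
    hooks = AgentHooksForTests()
    model = FakeModel()
    agent = Agent(
        name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks
    )
    # Queue two turns: the model first requests the tool, then emits the final message.
    model.add_multiple_turn_outputs(
        [
            [get_function_tool_call("f", "{}")],
            [get_text_message("done")],
        ]
    )
    await Runner.run(agent, input="hello")
    # Two LLM calls wrap one tool call: on_llm_start/on_llm_end fire twice, the tool hooks once.
    assert hooks.events == {
        "on_start": 1,
        "on_llm_start": 2,
        "on_llm_end": 2,
        "on_tool_start": 1,
        "on_tool_end": 1,
        "on_end": 1,
    }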