Skip to content

Commit 525a365

Browse files
committed
feat: add streaming and tests to new hooks
1 parent fc50e93 commit 525a365

File tree

2 files changed: +55 −4 lines

src/agents/run.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,9 @@ async def _run_single_turn_streamed(
812812

813813
input = ItemHelpers.input_to_new_input_list(streamed_result.input)
814814
input.extend([item.to_input_item() for item in streamed_result.new_items])
815-
815+
# Call hook just before the model is invoked, with the correct system_prompt.
816+
if agent.hooks:
817+
await agent.hooks.on_llm_start(context_wrapper, agent, system_prompt, input)
816818
# 1. Stream the output events
817819
async for event in model.stream_response(
818820
system_prompt,
@@ -849,6 +851,10 @@ async def _run_single_turn_streamed(
849851

850852
streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
851853

854+
# Call hook just after the model response is finalized.
855+
if agent.hooks:
856+
await agent.hooks.on_llm_end(context_wrapper, agent, final_response)
857+
852858
# 2. At this point, the streaming is complete for this turn of the agent loop.
853859
if not final_response:
854860
raise ModelBehaviorError("Model did not produce a final response!")

tests/test_agent_llm_hooks.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55

66
from agents.agent import Agent
7-
from agents.items import ModelResponse, TResponseInputItem
7+
from agents.items import ItemHelpers, ModelResponse, TResponseInputItem
88
from agents.lifecycle import AgentHooks
99
from agents.run import Runner
1010
from agents.run_context import RunContextWrapper, TContext
@@ -63,7 +63,7 @@ async def on_llm_start(
6363

6464
async def on_llm_end(
6565
self,
66-
ccontext: RunContextWrapper[TContext],
66+
context: RunContextWrapper[TContext],
6767
agent: Agent[TContext],
6868
response: ModelResponse,
6969
) -> None:
@@ -72,7 +72,7 @@ async def on_llm_end(
7272

7373
# Example test using the above hooks:
7474
@pytest.mark.asyncio
75-
async def test_non_streamed_agent_hooks_with_llm():
75+
async def test_async_agent_hooks_with_llm():
7676
hooks = AgentHooksForTests()
7777
model = FakeModel()
7878
agent = Agent(
@@ -83,3 +83,48 @@ async def test_non_streamed_agent_hooks_with_llm():
8383
await Runner.run(agent, input="hello")
8484
# Expect one on_start, one on_llm_start, one on_llm_end, and one on_end
8585
assert hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}
86+
87+
88+
def test_sync_agent_hook_with_llm():
    """Synchronous run: each lifecycle hook (on_start, on_llm_start, on_llm_end, on_end) fires once."""
    hooks = AgentHooksForTests()
    model = FakeModel()
    agent = Agent(
        name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks
    )
    # Simulate a single LLM call producing an output:
    model.set_next_output([get_text_message("hello")])
    Runner.run_sync(agent, input="hello")
    # Expect one on_start, one on_llm_start, one on_llm_end, and one on_end
    assert hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}
100+
101+
102+
@pytest.mark.asyncio
async def test_streamed_agent_hooks_with_llm():
    """Streamed run: each lifecycle hook (on_start, on_llm_start, on_llm_end, on_end) fires once.

    The stream must be fully drained before the hook counters can be checked,
    because the on_llm_end / on_end hooks fire as the stream completes.
    """
    hooks = AgentHooksForTests()
    model = FakeModel()
    agent = Agent(
        name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks
    )
    # Simulate a single LLM call producing an output:
    model.set_next_output([get_text_message("hello")])
    stream = Runner.run_streamed(agent, input="hello")

    # Drain the stream so the turn runs to completion. The original version
    # printed each event to stdout; tests should not emit debug noise, so we
    # only exercise the message-output helper to keep that code path covered.
    async for event in stream.stream_events():
        if event.type == "raw_response_event":
            continue
        if event.type == "run_item_stream_event" and event.item.type == "message_output_item":
            ItemHelpers.text_message_output(event.item)

    # Expect one on_start, one on_llm_start, one on_llm_end, and one on_end
    assert hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}

0 commit comments