41 changes: 33 additions & 8 deletions examples/basic/lifecycle_example.py
@@ -1,10 +1,11 @@
import asyncio
import random
from typing import Any
from typing import Any, Optional

from pydantic import BaseModel

from agents import Agent, RunContextWrapper, RunHooks, Runner, Tool, Usage, function_tool
from agents.items import ModelResponse, TResponseInputItem


class ExampleHooks(RunHooks):
@@ -20,6 +21,22 @@ async def on_agent_start(self, context: RunContextWrapper, agent: Agent) -> None
f"### {self.event_counter}: Agent {agent.name} started. Usage: {self._usage_to_str(context.usage)}"
)

async def on_llm_start(
self,
context: RunContextWrapper,
agent: Agent,
system_prompt: Optional[str],
input_items: list[TResponseInputItem],
) -> None:
self.event_counter += 1
print(f"### {self.event_counter}: LLM started. Usage: {self._usage_to_str(context.usage)}")

async def on_llm_end(
self, context: RunContextWrapper, agent: Agent, response: ModelResponse
) -> None:
self.event_counter += 1
print(f"### {self.event_counter}: LLM ended. Usage: {self._usage_to_str(context.usage)}")

async def on_agent_end(self, context: RunContextWrapper, agent: Agent, output: Any) -> None:
self.event_counter += 1
print(
@@ -109,13 +126,21 @@ async def main() -> None:

Enter a max number: 250
### 1: Agent Start Agent started. Usage: 0 requests, 0 input tokens, 0 output tokens, 0 total tokens
### 2: Tool random_number started. Usage: 1 requests, 148 input tokens, 15 output tokens, 163 total tokens
### 3: Tool random_number ended with result 101. Usage: 1 requests, 148 input tokens, 15 output tokens, 163 total token
### 4: Handoff from Start Agent to Multiply Agent. Usage: 2 requests, 323 input tokens, 30 output tokens, 353 total tokens
### 5: Agent Multiply Agent started. Usage: 2 requests, 323 input tokens, 30 output tokens, 353 total tokens
### 6: Tool multiply_by_two started. Usage: 3 requests, 504 input tokens, 46 output tokens, 550 total tokens
### 7: Tool multiply_by_two ended with result 202. Usage: 3 requests, 504 input tokens, 46 output tokens, 550 total tokens
### 8: Agent Multiply Agent ended with output number=202. Usage: 4 requests, 714 input tokens, 63 output tokens, 777 total tokens
### 2: LLM started. Usage: 0 requests, 0 input tokens, 0 output tokens, 0 total tokens
### 3: LLM ended. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
### 4: Tool random_number started. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
### 5: Tool random_number ended with result 69. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
### 6: LLM started. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
### 7: LLM ended. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
### 8: Handoff from Start Agent to Multiply Agent. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
### 9: Agent Multiply Agent started. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
### 10: LLM started. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
### 11: LLM ended. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
### 12: Tool multiply_by_two started. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
### 13: Tool multiply_by_two ended with result 138. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
### 14: LLM started. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
### 15: LLM ended. Usage: 4 requests, 660 input tokens, 56 output tokens, 716 total tokens
### 16: Agent Multiply Agent ended with output number=138. Usage: 4 requests, 660 input tokens, 56 output tokens, 716 total tokens
Done!

"""
63 changes: 46 additions & 17 deletions src/agents/run.py
@@ -994,10 +994,16 @@ async def _run_single_turn_streamed(
)

# Call hook just before the model is invoked, with the correct system_prompt.
if agent.hooks:
await agent.hooks.on_llm_start(
context_wrapper, agent, filtered.instructions, filtered.input
)
await asyncio.gather(
hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input),
(
agent.hooks.on_llm_start(
context_wrapper, agent, filtered.instructions, filtered.input
)
if agent.hooks
else _coro.noop_coroutine()
),
)

# 1. Stream the output events
async for event in model.stream_response(
@@ -1056,8 +1062,15 @@ async def _run_single_turn_streamed(
streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))

# Call hook just after the model response is finalized.
if agent.hooks and final_response is not None:
await agent.hooks.on_llm_end(context_wrapper, agent, final_response)
if final_response is not None:
await asyncio.gather(
(
agent.hooks.on_llm_end(context_wrapper, agent, final_response)
if agent.hooks
else _coro.noop_coroutine()
),
hooks.on_llm_end(context_wrapper, agent, final_response),
)

# 2. At this point, the streaming is complete for this turn of the agent loop.
if not final_response:
@@ -1150,6 +1163,7 @@ async def _run_single_turn(
output_schema,
all_tools,
handoffs,
hooks,
context_wrapper,
run_config,
tool_use_tracker,
@@ -1345,6 +1359,7 @@ async def _get_new_response(
output_schema: AgentOutputSchemaBase | None,
all_tools: list[Tool],
handoffs: list[Handoff],
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
tool_use_tracker: AgentToolUseTracker,
@@ -1364,14 +1379,21 @@
model = cls._get_model(agent, run_config)
model_settings = agent.model_settings.resolve(run_config.model_settings)
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
# If the agent has hooks, we need to call them before and after the LLM call
if agent.hooks:
await agent.hooks.on_llm_start(
context_wrapper,
agent,
filtered.instructions, # Use filtered instructions
filtered.input, # Use filtered input
)

# If we have run hooks, or if the agent has hooks, we need to call them before the LLM call
await asyncio.gather(
hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input),
(
agent.hooks.on_llm_start(
context_wrapper,
agent,
filtered.instructions, # Use filtered instructions
filtered.input, # Use filtered input
)
if agent.hooks
else _coro.noop_coroutine()
),
)

new_response = await model.get_response(
system_instructions=filtered.instructions,
Expand All @@ -1387,12 +1409,19 @@ async def _get_new_response(
conversation_id=conversation_id,
prompt=prompt_config,
)
# If the agent has hooks, we need to call them after the LLM call
if agent.hooks:
await agent.hooks.on_llm_end(context_wrapper, agent, new_response)

context_wrapper.usage.add(new_response.usage)

# If we have run hooks, or if the agent has hooks, we need to call them after the LLM call
Member: can you move this right after L1422's usage update? It shouldn't bring any visible overhead in processing time and can provide better insights for the callback

Contributor Author: Updated.

This fixes the issue I was seeing when running the lifecycle test. Thanks!

(See the sketch after this file's diff for what the new ordering gives the callback.)

await asyncio.gather(
(
agent.hooks.on_llm_end(context_wrapper, agent, new_response)
if agent.hooks
else _coro.noop_coroutine()
),
hooks.on_llm_end(context_wrapper, agent, new_response),
)

return new_response

@classmethod
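
The review exchange above asked for the on_llm_end dispatch to happen only after context_wrapper.usage.add(new_response.usage), so that the callback sees up-to-date totals. A minimal sketch of what that ordering lets a hook rely on; UsageAuditHooks is a hypothetical name, and Usage.total_tokens is assumed to match the totals printed by the lifecycle example.

from agents import Agent, RunContextWrapper, RunHooks
from agents.items import ModelResponse


class UsageAuditHooks(RunHooks):
    async def on_llm_end(
        self, context: RunContextWrapper, agent: Agent, response: ModelResponse
    ) -> None:
        # Because the hook now fires after context_wrapper.usage.add(...),
        # the run-wide totals already include this response's usage.
        print(
            f"call used {response.usage.total_tokens} tokens; "
            f"run total is now {context.usage.total_tokens}"
        )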
223 changes: 223 additions & 0 deletions tests/test_run_hooks.py
@@ -0,0 +1,223 @@
from collections import defaultdict
from typing import Any, Optional

import pytest

from agents.agent import Agent
from agents.items import ItemHelpers, ModelResponse, TResponseInputItem
from agents.lifecycle import RunHooks
from agents.models.interface import Model
from agents.run import Runner
from agents.run_context import RunContextWrapper, TContext
from agents.tool import Tool
from tests.test_agent_llm_hooks import AgentHooksForTests

from .fake_model import FakeModel
from .test_responses import (
get_function_tool,
get_text_message,
)


class RunHooksForTests(RunHooks):
def __init__(self):
self.events: dict[str, int] = defaultdict(int)

def reset(self):
self.events.clear()

async def on_agent_start(
self, context: RunContextWrapper[TContext], agent: Agent[TContext]
) -> None:
self.events["on_agent_start"] += 1

async def on_agent_end(
self, context: RunContextWrapper[TContext], agent: Agent[TContext], output: Any
) -> None:
self.events["on_agent_end"] += 1

async def on_handoff(
self,
context: RunContextWrapper[TContext],
from_agent: Agent[TContext],
to_agent: Agent[TContext],
) -> None:
self.events["on_handoff"] += 1

async def on_tool_start(
self, context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool
) -> None:
self.events["on_tool_start"] += 1

async def on_tool_end(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
tool: Tool,
result: str,
) -> None:
self.events["on_tool_end"] += 1

async def on_llm_start(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
system_prompt: Optional[str],
input_items: list[TResponseInputItem],
) -> None:
self.events["on_llm_start"] += 1

async def on_llm_end(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
response: ModelResponse,
) -> None:
self.events["on_llm_end"] += 1


# Example test using the above hooks
@pytest.mark.asyncio
async def test_async_run_hooks_with_llm():
hooks = RunHooksForTests()
model = FakeModel()

agent = Agent(name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[])
# Simulate a single LLM call producing an output:
model.set_next_output([get_text_message("hello")])
await Runner.run(agent, input="hello", hooks=hooks)
# Expect one on_agent_start, one on_llm_start, one on_llm_end, and one on_agent_end
assert hooks.events == {
"on_agent_start": 1,
"on_llm_start": 1,
"on_llm_end": 1,
"on_agent_end": 1,
}


# test_sync_run_hook_with_llm()
def test_sync_run_hook_with_llm():
hooks = RunHooksForTests()
model = FakeModel()
agent = Agent(name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[])
# Simulate a single LLM call producing an output:
model.set_next_output([get_text_message("hello")])
Runner.run_sync(agent, input="hello", hooks=hooks)
# Expect one on_agent_start, one on_llm_start, one on_llm_end, and one on_agent_end
assert hooks.events == {
"on_agent_start": 1,
"on_llm_start": 1,
"on_llm_end": 1,
"on_agent_end": 1,
}


# test_streamed_run_hooks_with_llm():
@pytest.mark.asyncio
async def test_streamed_run_hooks_with_llm():
hooks = RunHooksForTests()
model = FakeModel()
agent = Agent(name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[])
# Simulate a single LLM call producing an output:
model.set_next_output([get_text_message("hello")])
stream = Runner.run_streamed(agent, input="hello", hooks=hooks)

async for event in stream.stream_events():
if event.type == "raw_response_event":
continue
if event.type == "agent_updated_stream_event":
print(f"[EVENT] agent_updated → {event.new_agent.name}")
elif event.type == "run_item_stream_event":
item = event.item
if item.type == "tool_call_item":
print("[EVENT] tool_call_item")
elif item.type == "tool_call_output_item":
print(f"[EVENT] tool_call_output_item → {item.output}")
elif item.type == "message_output_item":
text = ItemHelpers.text_message_output(item)
print(f"[EVENT] message_output_item → {text}")

# Expect one on_agent_start, one on_llm_start, one on_llm_end, and one on_agent_end
assert hooks.events == {
"on_agent_start": 1,
"on_llm_start": 1,
"on_llm_end": 1,
"on_agent_end": 1,
}


# test_async_run_hooks_with_agent_hooks_with_llm
@pytest.mark.asyncio
async def test_async_run_hooks_with_agent_hooks_with_llm():
hooks = RunHooksForTests()
agent_hooks = AgentHooksForTests()
model = FakeModel()

agent = Agent(
name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=agent_hooks
)
# Simulate a single LLM call producing an output:
model.set_next_output([get_text_message("hello")])
await Runner.run(agent, input="hello", hooks=hooks)
# Expect one on_agent_start, one on_llm_start, one on_llm_end, and one on_agent_end
assert hooks.events == {
"on_agent_start": 1,
"on_llm_start": 1,
"on_llm_end": 1,
"on_agent_end": 1,
}
# Expect one on_start, one on_llm_start, one on_llm_end, and one on_end
assert agent_hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}


@pytest.mark.asyncio
async def test_run_hooks_llm_error_non_streaming(monkeypatch):
hooks = RunHooksForTests()
model = FakeModel()
agent = Agent(name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[])

async def boom(*args, **kwargs):
raise RuntimeError("boom")

monkeypatch.setattr(FakeModel, "get_response", boom, raising=True)

with pytest.raises(RuntimeError, match="boom"):
await Runner.run(agent, input="hello", hooks=hooks)

# Current behavior: the end hooks do not fire when the LLM call fails; the start hooks already ran
assert hooks.events["on_agent_start"] == 1
assert hooks.events["on_llm_start"] == 1
assert hooks.events["on_llm_end"] == 0
assert hooks.events["on_agent_end"] == 0


class BoomModel(Model):
async def get_response(self, *a, **k):
raise AssertionError("get_response should not be called in streaming test")

async def stream_response(self, *a, **k):
yield {"foo": "bar"}
raise RuntimeError("stream blew up")


@pytest.mark.asyncio
async def test_streamed_run_hooks_llm_error(monkeypatch):
"""
Verify that when the streaming path raises, we still emit on_llm_start
but do NOT emit on_llm_end (current behavior), and the exception propagates.
"""
hooks = RunHooksForTests()
agent = Agent(name="A", model=BoomModel(), tools=[get_function_tool("f", "res")], handoffs=[])

stream = Runner.run_streamed(agent, input="hello", hooks=hooks)

# Consuming the stream should surface the exception
with pytest.raises(RuntimeError, match="stream blew up"):
async for _ in stream.stream_events():
pass

# Current behavior: success-only on_llm_end; ensure starts fired but ends did not.
assert hooks.events["on_agent_start"] == 1
assert hooks.events["on_llm_start"] == 1
assert hooks.events["on_llm_end"] == 0
assert hooks.events["on_agent_end"] == 0