Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/agents/extensions/handoff_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ..items import (
HandoffCallItem,
HandoffOutputItem,
ReasoningItem,
RunItem,
ToolCallItem,
ToolCallOutputItem,
Expand Down Expand Up @@ -41,6 +42,7 @@ def _remove_tools_from_items(items: tuple[RunItem, ...]) -> tuple[RunItem, ...]:
or isinstance(item, HandoffOutputItem)
or isinstance(item, ToolCallItem)
or isinstance(item, ToolCallOutputItem)
or isinstance(item, ReasoningItem)
):
continue
filtered_items.append(item)
Expand Down
19 changes: 18 additions & 1 deletion tests/test_extension_filters.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from openai.types.responses import ResponseOutputMessage, ResponseOutputText
from openai.types.responses.response_reasoning_item import ResponseReasoningItem

from agents import Agent, HandoffInputData, RunContextWrapper
from agents.extensions.handoff_filters import remove_all_tools
from agents.items import (
HandoffOutputItem,
MessageOutputItem,
ReasoningItem,
ToolCallOutputItem,
TResponseInputItem,
)
Expand All @@ -23,6 +25,10 @@ def _get_message_input_item(content: str) -> TResponseInputItem:
}


def _get_reasoning_input_item() -> TResponseInputItem:
return {"id": "rid", "summary": [], "type": "reasoning"}


def _get_function_result_input_item(content: str) -> TResponseInputItem:
return {
"call_id": "1",
Expand Down Expand Up @@ -77,6 +83,12 @@ def _get_handoff_output_run_item(content: str) -> HandoffOutputItem:
)


def _get_reasoning_output_run_item() -> ReasoningItem:
return ReasoningItem(
agent=fake_agent(), raw_item=ResponseReasoningItem(id="rid", summary=[], type="reasoning")
)


def test_empty_data():
handoff_input_data = HandoffInputData(
input_history=(),
Expand Down Expand Up @@ -161,21 +173,24 @@ def test_removes_tools_from_new_items_and_history():
handoff_input_data = HandoffInputData(
input_history=(
_get_message_input_item("Hello1"),
_get_reasoning_input_item(),
_get_function_result_input_item("World"),
_get_message_input_item("Hello2"),
),
pre_handoff_items=(
_get_reasoning_output_run_item(),
_get_message_output_run_item("123"),
_get_tool_output_run_item("456"),
),
new_items=(
_get_reasoning_output_run_item(),
_get_message_output_run_item("Hello"),
_get_tool_output_run_item("World"),
),
run_context=RunContextWrapper(context=()),
)
filtered_data = remove_all_tools(handoff_input_data)
assert len(filtered_data.input_history) == 2
assert len(filtered_data.input_history) == 3
assert len(filtered_data.pre_handoff_items) == 1
assert len(filtered_data.new_items) == 1

Expand All @@ -187,11 +202,13 @@ def test_removes_handoffs_from_history():
_get_handoff_input_item("World"),
),
pre_handoff_items=(
_get_reasoning_output_run_item(),
_get_message_output_run_item("Hello"),
_get_tool_output_run_item("World"),
_get_handoff_output_run_item("World"),
),
new_items=(
_get_reasoning_output_run_item(),
_get_message_output_run_item("Hello"),
_get_tool_output_run_item("World"),
_get_handoff_output_run_item("World"),
Expand Down