diff --git a/src/agents/extensions/handoff_filters.py b/src/agents/extensions/handoff_filters.py index 4abe99a45..a4433ae0c 100644 --- a/src/agents/extensions/handoff_filters.py +++ b/src/agents/extensions/handoff_filters.py @@ -4,6 +4,7 @@ from ..items import ( HandoffCallItem, HandoffOutputItem, + ReasoningItem, RunItem, ToolCallItem, ToolCallOutputItem, @@ -41,6 +42,7 @@ def _remove_tools_from_items(items: tuple[RunItem, ...]) -> tuple[RunItem, ...]: or isinstance(item, HandoffOutputItem) or isinstance(item, ToolCallItem) or isinstance(item, ToolCallOutputItem) + or isinstance(item, ReasoningItem) ): continue filtered_items.append(item) diff --git a/tests/test_extension_filters.py b/tests/test_extension_filters.py index 3c2ba9e4f..11fba51ba 100644 --- a/tests/test_extension_filters.py +++ b/tests/test_extension_filters.py @@ -1,10 +1,12 @@ from openai.types.responses import ResponseOutputMessage, ResponseOutputText +from openai.types.responses.response_reasoning_item import ResponseReasoningItem from agents import Agent, HandoffInputData, RunContextWrapper from agents.extensions.handoff_filters import remove_all_tools from agents.items import ( HandoffOutputItem, MessageOutputItem, + ReasoningItem, ToolCallOutputItem, TResponseInputItem, ) @@ -23,6 +25,10 @@ def _get_message_input_item(content: str) -> TResponseInputItem: } +def _get_reasoning_input_item() -> TResponseInputItem: + return {"id": "rid", "summary": [], "type": "reasoning"} + + def _get_function_result_input_item(content: str) -> TResponseInputItem: return { "call_id": "1", @@ -77,6 +83,12 @@ def _get_handoff_output_run_item(content: str) -> HandoffOutputItem: ) +def _get_reasoning_output_run_item() -> ReasoningItem: + return ReasoningItem( + agent=fake_agent(), raw_item=ResponseReasoningItem(id="rid", summary=[], type="reasoning") + ) + + def test_empty_data(): handoff_input_data = HandoffInputData( input_history=(), @@ -161,21 +173,24 @@ def test_removes_tools_from_new_items_and_history(): handoff_input_data = HandoffInputData( input_history=( _get_message_input_item("Hello1"), + _get_reasoning_input_item(), _get_function_result_input_item("World"), _get_message_input_item("Hello2"), ), pre_handoff_items=( + _get_reasoning_output_run_item(), _get_message_output_run_item("123"), _get_tool_output_run_item("456"), ), new_items=( + _get_reasoning_output_run_item(), _get_message_output_run_item("Hello"), _get_tool_output_run_item("World"), ), run_context=RunContextWrapper(context=()), ) filtered_data = remove_all_tools(handoff_input_data) - assert len(filtered_data.input_history) == 2 + assert len(filtered_data.input_history) == 3 assert len(filtered_data.pre_handoff_items) == 1 assert len(filtered_data.new_items) == 1 @@ -187,11 +202,13 @@ def test_removes_handoffs_from_history(): _get_handoff_input_item("World"), ), pre_handoff_items=( + _get_reasoning_output_run_item(), _get_message_output_run_item("Hello"), _get_tool_output_run_item("World"), _get_handoff_output_run_item("World"), ), new_items=( + _get_reasoning_output_run_item(), _get_message_output_run_item("Hello"), _get_tool_output_run_item("World"), _get_handoff_output_run_item("World"),