Skip to content

Commit 0233cc0

Browse files
authored
Fix #1238 by enhancing HandoffInputData and enable passing async functions (#1302)
This pull request resolves #1238
1 parent ef812c5 commit 0233cc0

File tree

9 files changed

+61
-19
lines changed

9 files changed

+61
-19
lines changed

examples/handoffs/message_filter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def spanish_handoff_message_filter(handoff_message_data: HandoffInputData) -> Ha
2424
else handoff_message_data.input_history
2525
)
2626

27+
# or, you can use the HandoffInputData.clone(kwargs) method
2728
return HandoffInputData(
2829
input_history=history,
2930
pre_handoff_items=tuple(handoff_message_data.pre_handoff_items),

examples/handoffs/message_filter_streaming.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def spanish_handoff_message_filter(handoff_message_data: HandoffInputData) -> Ha
2424
else handoff_message_data.input_history
2525
)
2626

27+
# or, you can use the HandoffInputData.clone(kwargs) method
2728
return HandoffInputData(
2829
input_history=history,
2930
pre_handoff_items=tuple(handoff_message_data.pre_handoff_items),

src/agents/_run_impl.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,7 @@ async def execute_handoffs(
774774
else original_input,
775775
pre_handoff_items=tuple(pre_step_items),
776776
new_items=tuple(new_step_items),
777+
run_context=context_wrapper,
777778
)
778779
if not callable(input_filter):
779780
_error_tracing.attach_error_to_span(
@@ -785,6 +786,8 @@ async def execute_handoffs(
785786
)
786787
raise UserError(f"Invalid input filter: {input_filter}")
787788
filtered = input_filter(handoff_input_data)
789+
if inspect.isawaitable(filtered):
790+
filtered = await filtered
788791
if not isinstance(filtered, HandoffInputData):
789792
_error_tracing.attach_error_to_span(
790793
span_handoff,

src/agents/extensions/handoff_filters.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def remove_all_tools(handoff_input_data: HandoffInputData) -> HandoffInputData:
2929
input_history=filtered_history,
3030
pre_handoff_items=filtered_pre_handoff_items,
3131
new_items=filtered_new_items,
32+
run_context=handoff_input_data.run_context,
3233
)
3334

3435

src/agents/handoffs.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import inspect
44
import json
55
from collections.abc import Awaitable
6-
from dataclasses import dataclass
6+
from dataclasses import dataclass, replace as dataclasses_replace
77
from typing import TYPE_CHECKING, Any, Callable, Generic, cast, overload
88

99
from pydantic import TypeAdapter
@@ -49,8 +49,24 @@ class HandoffInputData:
4949
handoff and the tool output message representing the response from the handoff output.
5050
"""
5151

52+
run_context: RunContextWrapper[Any] | None = None
53+
"""
54+
The run context at the time the handoff was invoked.
55+
Note that, since this property was added later on, it's optional for backwards compatibility.
56+
"""
57+
58+
def clone(self, **kwargs: Any) -> HandoffInputData:
59+
"""
60+
Make a copy of the handoff input data, with the given arguments changed. For example, you
61+
could do:
62+
```
63+
new_handoff_input_data = handoff_input_data.clone(new_items=())
64+
```
65+
"""
66+
return dataclasses_replace(self, **kwargs)
67+
5268

53-
HandoffInputFilter: TypeAlias = Callable[[HandoffInputData], HandoffInputData]
69+
HandoffInputFilter: TypeAlias = Callable[[HandoffInputData], MaybeAwaitable[HandoffInputData]]
5470
"""A function that filters the input data passed to the next agent."""
5571

5672

tests/test_agent_runner.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def remove_new_items(handoff_input_data: HandoffInputData) -> HandoffInputData:
224224
input_history=handoff_input_data.input_history,
225225
pre_handoff_items=(),
226226
new_items=(),
227+
run_context=handoff_input_data.run_context,
227228
)
228229

229230

@@ -262,7 +263,7 @@ async def test_handoff_filters():
262263

263264

264265
@pytest.mark.asyncio
265-
async def test_async_input_filter_fails():
266+
async def test_async_input_filter_supported():
266267
# DO NOT rename this without updating pyproject.toml
267268

268269
model = FakeModel()
@@ -274,7 +275,7 @@ async def test_async_input_filter_fails():
274275
async def on_invoke_handoff(_ctx: RunContextWrapper[Any], _input: str) -> Agent[Any]:
275276
return agent_1
276277

277-
async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData:
278+
async def async_input_filter(data: HandoffInputData) -> HandoffInputData:
278279
return data # pragma: no cover
279280

280281
agent_2 = Agent[None](
@@ -287,8 +288,7 @@ async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData:
287288
input_json_schema={},
288289
on_invoke_handoff=on_invoke_handoff,
289290
agent_name=agent_1.name,
290-
# Purposely ignoring the type error here to simulate invalid input
291-
input_filter=invalid_input_filter, # type: ignore
291+
input_filter=async_input_filter,
292292
)
293293
],
294294
)
@@ -300,8 +300,8 @@ async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData:
300300
]
301301
)
302302

303-
with pytest.raises(UserError):
304-
await Runner.run(agent_2, input="user_message")
303+
result = await Runner.run(agent_2, input="user_message")
304+
assert result.final_output == "last"
305305

306306

307307
@pytest.mark.asyncio

tests/test_agent_runner_streamed.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ def remove_new_items(handoff_input_data: HandoffInputData) -> HandoffInputData:
241241
input_history=handoff_input_data.input_history,
242242
pre_handoff_items=(),
243243
new_items=(),
244+
run_context=handoff_input_data.run_context,
244245
)
245246

246247

@@ -281,7 +282,7 @@ async def test_handoff_filters():
281282

282283

283284
@pytest.mark.asyncio
284-
async def test_async_input_filter_fails():
285+
async def test_async_input_filter_supported():
285286
# DO NOT rename this without updating pyproject.toml
286287

287288
model = FakeModel()
@@ -293,7 +294,7 @@ async def test_async_input_filter_fails():
293294
async def on_invoke_handoff(_ctx: RunContextWrapper[Any], _input: str) -> Agent[Any]:
294295
return agent_1
295296

296-
async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData:
297+
async def async_input_filter(data: HandoffInputData) -> HandoffInputData:
297298
return data # pragma: no cover
298299

299300
agent_2 = Agent[None](
@@ -306,8 +307,7 @@ async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData:
306307
input_json_schema={},
307308
on_invoke_handoff=on_invoke_handoff,
308309
agent_name=agent_1.name,
309-
# Purposely ignoring the type error here to simulate invalid input
310-
input_filter=invalid_input_filter, # type: ignore
310+
input_filter=async_input_filter,
311311
)
312312
],
313313
)
@@ -319,10 +319,9 @@ async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData:
319319
]
320320
)
321321

322-
with pytest.raises(UserError):
323-
result = Runner.run_streamed(agent_2, input="user_message")
324-
async for _ in result.stream_events():
325-
pass
322+
result = Runner.run_streamed(agent_2, input="user_message")
323+
async for _ in result.stream_events():
324+
pass
326325

327326

328327
@pytest.mark.asyncio

tests/test_extension_filters.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from openai.types.responses import ResponseOutputMessage, ResponseOutputText
22

3-
from agents import Agent, HandoffInputData
3+
from agents import Agent, HandoffInputData, RunContextWrapper
44
from agents.extensions.handoff_filters import remove_all_tools
55
from agents.items import (
66
HandoffOutputItem,
@@ -78,13 +78,23 @@ def _get_handoff_output_run_item(content: str) -> HandoffOutputItem:
7878

7979

8080
def test_empty_data():
81-
handoff_input_data = HandoffInputData(input_history=(), pre_handoff_items=(), new_items=())
81+
handoff_input_data = HandoffInputData(
82+
input_history=(),
83+
pre_handoff_items=(),
84+
new_items=(),
85+
run_context=RunContextWrapper(context=()),
86+
)
8287
filtered_data = remove_all_tools(handoff_input_data)
8388
assert filtered_data == handoff_input_data
8489

8590

8691
def test_str_historyonly():
87-
handoff_input_data = HandoffInputData(input_history="Hello", pre_handoff_items=(), new_items=())
92+
handoff_input_data = HandoffInputData(
93+
input_history="Hello",
94+
pre_handoff_items=(),
95+
new_items=(),
96+
run_context=RunContextWrapper(context=()),
97+
)
8898
filtered_data = remove_all_tools(handoff_input_data)
8999
assert filtered_data == handoff_input_data
90100

@@ -94,6 +104,7 @@ def test_str_history_and_list():
94104
input_history="Hello",
95105
pre_handoff_items=(),
96106
new_items=(_get_message_output_run_item("Hello"),),
107+
run_context=RunContextWrapper(context=()),
97108
)
98109
filtered_data = remove_all_tools(handoff_input_data)
99110
assert filtered_data == handoff_input_data
@@ -104,6 +115,7 @@ def test_list_history_and_list():
104115
input_history=(_get_message_input_item("Hello"),),
105116
pre_handoff_items=(_get_message_output_run_item("123"),),
106117
new_items=(_get_message_output_run_item("World"),),
118+
run_context=RunContextWrapper(context=()),
107119
)
108120
filtered_data = remove_all_tools(handoff_input_data)
109121
assert filtered_data == handoff_input_data
@@ -121,6 +133,7 @@ def test_removes_tools_from_history():
121133
_get_message_output_run_item("123"),
122134
),
123135
new_items=(_get_message_output_run_item("World"),),
136+
run_context=RunContextWrapper(context=()),
124137
)
125138
filtered_data = remove_all_tools(handoff_input_data)
126139
assert len(filtered_data.input_history) == 2
@@ -136,6 +149,7 @@ def test_removes_tools_from_new_items():
136149
_get_message_output_run_item("Hello"),
137150
_get_tool_output_run_item("World"),
138151
),
152+
run_context=RunContextWrapper(context=()),
139153
)
140154
filtered_data = remove_all_tools(handoff_input_data)
141155
assert len(filtered_data.input_history) == 0
@@ -158,6 +172,7 @@ def test_removes_tools_from_new_items_and_history():
158172
_get_message_output_run_item("Hello"),
159173
_get_tool_output_run_item("World"),
160174
),
175+
run_context=RunContextWrapper(context=()),
161176
)
162177
filtered_data = remove_all_tools(handoff_input_data)
163178
assert len(filtered_data.input_history) == 2
@@ -181,6 +196,7 @@ def test_removes_handoffs_from_history():
181196
_get_tool_output_run_item("World"),
182197
_get_handoff_output_run_item("World"),
183198
),
199+
run_context=RunContextWrapper(context=()),
184200
)
185201
filtered_data = remove_all_tools(handoff_input_data)
186202
assert len(filtered_data.input_history) == 1

tests/test_handoff_tool.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,13 +221,15 @@ def test_handoff_input_data():
221221
input_history="",
222222
pre_handoff_items=(),
223223
new_items=(),
224+
run_context=RunContextWrapper(context=()),
224225
)
225226
assert get_len(data) == 1
226227

227228
data = HandoffInputData(
228229
input_history=({"role": "user", "content": "foo"},),
229230
pre_handoff_items=(),
230231
new_items=(),
232+
run_context=RunContextWrapper(context=()),
231233
)
232234
assert get_len(data) == 1
233235

@@ -238,6 +240,7 @@ def test_handoff_input_data():
238240
),
239241
pre_handoff_items=(),
240242
new_items=(),
243+
run_context=RunContextWrapper(context=()),
241244
)
242245
assert get_len(data) == 2
243246

@@ -251,6 +254,7 @@ def test_handoff_input_data():
251254
message_item("bar", agent),
252255
message_item("baz", agent),
253256
),
257+
run_context=RunContextWrapper(context=()),
254258
)
255259
assert get_len(data) == 5
256260

@@ -264,6 +268,7 @@ def test_handoff_input_data():
264268
message_item("baz", agent),
265269
message_item("qux", agent),
266270
),
271+
run_context=RunContextWrapper(context=()),
267272
)
268273

269274
assert get_len(data) == 5

0 commit comments

Comments
 (0)