Skip to content

Commit 8c6a2f4

Browse files
committed
Merge upstream and extend cleanup to tool call tasks
Resolved conflict by applying the same async cleanup pattern to both guardrail and tool call tasks. Both methods now properly cancel and await tasks using asyncio.gather with return_exceptions=True. Added comprehensive test coverage for tool call task cleanup: - Test cancellation and awaiting of tool call tasks - Test exception handling during cleanup - Test that _cleanup() awaits both task types This ensures consistent resource cleanup across all background tasks in realtime sessions.
2 parents 3077b3a + 8c4d4d0 commit 8c6a2f4

File tree

7 files changed

+313
-21
lines changed

7 files changed

+313
-21
lines changed

examples/realtime/app/agent.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
13
from agents import function_tool
24
from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX
35
from agents.realtime import RealtimeAgent, realtime_handoff
@@ -13,20 +15,26 @@
1315
name_override="faq_lookup_tool", description_override="Lookup frequently asked questions."
1416
)
1517
async def faq_lookup_tool(question: str) -> str:
16-
if "bag" in question or "baggage" in question:
18+
print("faq_lookup_tool called with question:", question)
19+
20+
# Simulate a slow API call
21+
await asyncio.sleep(3)
22+
23+
q = question.lower()
24+
if "wifi" in q or "wi-fi" in q:
25+
return "We have free wifi on the plane, join Airline-Wifi"
26+
elif "bag" in q or "baggage" in q:
1727
return (
1828
"You are allowed to bring one bag on the plane. "
1929
"It must be under 50 pounds and 22 inches x 14 inches x 9 inches."
2030
)
21-
elif "seats" in question or "plane" in question:
31+
elif "seats" in q or "plane" in q:
2232
return (
2333
"There are 120 seats on the plane. "
2434
"There are 22 business class seats and 98 economy seats. "
2535
"Exit rows are rows 4 and 16. "
2636
"Rows 5-8 are Economy Plus, with extra legroom. "
2737
)
28-
elif "wifi" in question:
29-
return "We have free wifi on the plane, join Airline-Wifi"
3038
return "I'm sorry, I don't know the answer to that question."
3139

3240

examples/realtime/app/server.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ async def connect(self, websocket: WebSocket, session_id: str):
4747

4848
agent = get_starting_agent()
4949
runner = RealtimeRunner(agent)
50+
# If you want to customize the runner behavior, you can pass options:
51+
# runner_config = RealtimeRunConfig(async_tool_calls=False)
52+
# runner = RealtimeRunner(agent, config=runner_config)
5053
model_config: RealtimeModelConfig = {
5154
"initial_model_settings": {
5255
"turn_detection": {

src/agents/models/chatcmpl_stream_handler.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,12 @@ async def handle_stream(
150150
)
151151

152152
if reasoning_content and state.reasoning_content_index_and_output:
153+
# Ensure summary list has at least one element
154+
if not state.reasoning_content_index_and_output[1].summary:
155+
state.reasoning_content_index_and_output[1].summary = [
156+
Summary(text="", type="summary_text")
157+
]
158+
153159
yield ResponseReasoningSummaryTextDeltaEvent(
154160
delta=reasoning_content,
155161
item_id=FAKE_RESPONSES_ID,
@@ -201,7 +207,7 @@ async def handle_stream(
201207
)
202208

203209
# Create a new summary with updated text
204-
if state.reasoning_content_index_and_output[1].content is None:
210+
if not state.reasoning_content_index_and_output[1].content:
205211
state.reasoning_content_index_and_output[1].content = [
206212
Content(text="", type="reasoning_text")
207213
]

src/agents/realtime/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ class RealtimeRunConfig(TypedDict):
184184
tracing_disabled: NotRequired[bool]
185185
"""Whether tracing is disabled for this run."""
186186

187+
async_tool_calls: NotRequired[bool]
188+
"""Whether function tool calls should run asynchronously. Defaults to True."""
189+
187190
# TODO (rm) Add history audio storage config
188191

189192

src/agents/realtime/session.py

Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(
112112
}
113113
self._event_queue: asyncio.Queue[RealtimeSessionEvent] = asyncio.Queue()
114114
self._closed = False
115-
self._stored_exception: Exception | None = None
115+
self._stored_exception: BaseException | None = None
116116

117117
# Guardrails state tracking
118118
self._interrupted_response_ids: set[str] = set()
@@ -123,6 +123,8 @@ def __init__(
123123
)
124124

125125
self._guardrail_tasks: set[asyncio.Task[Any]] = set()
126+
self._tool_call_tasks: set[asyncio.Task[Any]] = set()
127+
self._async_tool_calls: bool = bool(self._run_config.get("async_tool_calls", True))
126128

127129
@property
128130
def model(self) -> RealtimeModel:
@@ -216,7 +218,11 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
216218
if event.type == "error":
217219
await self._put_event(RealtimeError(info=self._event_info, error=event.error))
218220
elif event.type == "function_call":
219-
await self._handle_tool_call(event)
221+
agent_snapshot = self._current_agent
222+
if self._async_tool_calls:
223+
self._enqueue_tool_call_task(event, agent_snapshot)
224+
else:
225+
await self._handle_tool_call(event, agent_snapshot=agent_snapshot)
220226
elif event.type == "audio":
221227
await self._put_event(
222228
RealtimeAudio(
@@ -384,11 +390,17 @@ async def _put_event(self, event: RealtimeSessionEvent) -> None:
384390
"""Put an event into the queue."""
385391
await self._event_queue.put(event)
386392

387-
async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
393+
async def _handle_tool_call(
394+
self,
395+
event: RealtimeModelToolCallEvent,
396+
*,
397+
agent_snapshot: RealtimeAgent | None = None,
398+
) -> None:
388399
"""Handle a tool call event."""
400+
agent = agent_snapshot or self._current_agent
389401
tools, handoffs = await asyncio.gather(
390-
self._current_agent.get_all_tools(self._context_wrapper),
391-
self._get_handoffs(self._current_agent, self._context_wrapper),
402+
agent.get_all_tools(self._context_wrapper),
403+
self._get_handoffs(agent, self._context_wrapper),
392404
)
393405
function_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)}
394406
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
@@ -398,7 +410,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
398410
RealtimeToolStart(
399411
info=self._event_info,
400412
tool=function_map[event.name],
401-
agent=self._current_agent,
413+
agent=agent,
402414
)
403415
)
404416

@@ -423,7 +435,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
423435
info=self._event_info,
424436
tool=func_tool,
425437
output=result,
426-
agent=self._current_agent,
438+
agent=agent,
427439
)
428440
)
429441
elif event.name in handoff_map:
@@ -444,7 +456,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
444456
)
445457

446458
# Store previous agent for event
447-
previous_agent = self._current_agent
459+
previous_agent = agent
448460

449461
# Update current agent
450462
self._current_agent = result
@@ -762,10 +774,59 @@ async def _cleanup_guardrail_tasks(self) -> None:
762774

763775
self._guardrail_tasks.clear()
764776

777+
def _enqueue_tool_call_task(
778+
self, event: RealtimeModelToolCallEvent, agent_snapshot: RealtimeAgent
779+
) -> None:
780+
"""Run tool calls in the background to avoid blocking realtime transport."""
781+
task = asyncio.create_task(self._handle_tool_call(event, agent_snapshot=agent_snapshot))
782+
self._tool_call_tasks.add(task)
783+
task.add_done_callback(self._on_tool_call_task_done)
784+
785+
def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None:
786+
self._tool_call_tasks.discard(task)
787+
788+
if task.cancelled():
789+
return
790+
791+
exception = task.exception()
792+
if exception is None:
793+
return
794+
795+
logger.exception("Realtime tool call task failed", exc_info=exception)
796+
797+
if self._stored_exception is None:
798+
self._stored_exception = exception
799+
800+
asyncio.create_task(
801+
self._put_event(
802+
RealtimeError(
803+
info=self._event_info,
804+
error={"message": f"Tool call task failed: {exception}"},
805+
)
806+
)
807+
)
808+
809+
async def _cleanup_tool_call_tasks(self) -> None:
810+
"""Cancel all pending tool call tasks and wait for them to complete.
811+
812+
This ensures that any exceptions raised by the tasks are properly handled
813+
and prevents warnings about unhandled task exceptions.
814+
"""
815+
for task in self._tool_call_tasks:
816+
if not task.done():
817+
task.cancel()
818+
819+
# Wait for all tasks to complete and collect any exceptions
820+
if self._tool_call_tasks:
821+
await asyncio.gather(*self._tool_call_tasks, return_exceptions=True)
822+
823+
self._tool_call_tasks.clear()
824+
765825
async def _cleanup(self) -> None:
766826
"""Clean up all resources and mark session as closed."""
767827
# Cancel and cleanup guardrail tasks
768828
await self._cleanup_guardrail_tasks()
829+
await self._cleanup_tool_call_tasks()
769830

770831
# Remove ourselves as a listener
771832
self._model.remove_listener(self)

0 commit comments

Comments
 (0)