Skip to content

Commit 3537b62

Browse files
committed
fix review comment
1 parent baed3fe commit 3537b62

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

src/agents/run.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import contextlib
45
import inspect
56
import os
67
import warnings
@@ -751,7 +752,7 @@ def run_sync(
751752
# We intentionally leave the default loop open even if we had to create one above. Session
752753
# instances and other helpers stash loop-bound primitives between calls and expect to find
753754
# the same default loop every time run_sync is invoked on this thread.
754-
return default_loop.run_until_complete(
755+
task = default_loop.create_task(
755756
self.run(
756757
starting_agent,
757758
input,
@@ -765,6 +766,15 @@ def run_sync(
765766
)
766767
)
767768

769+
try:
770+
return default_loop.run_until_complete(task)
771+
except BaseException:
772+
if not task.done():
773+
task.cancel()
774+
with contextlib.suppress(asyncio.CancelledError):
775+
default_loop.run_until_complete(task)
776+
raise
777+
768778
def run_streamed(
769779
self,
770780
starting_agent: Agent[TContext],

tests/test_agent_runner_sync.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
from collections.abc import Generator
3+
from typing import Any
34

45
import pytest
56

@@ -72,3 +73,48 @@ async def invoke():
7273
runner.run_sync(Agent(name="test-agent"), "input")
7374

7475
asyncio.run(invoke())
76+
77+
78+
def test_run_sync_cancels_task_when_interrupted(monkeypatch, fresh_event_loop_policy):
79+
runner = AgentRunner()
80+
81+
async def fake_run(self, *_args, **_kwargs):
82+
await asyncio.sleep(3600)
83+
84+
monkeypatch.setattr(AgentRunner, "run", fake_run, raising=False)
85+
86+
test_loop = asyncio.new_event_loop()
87+
fresh_event_loop_policy.set_event_loop(test_loop)
88+
89+
created_tasks: list[asyncio.Task[Any]] = []
90+
original_create_task = test_loop.create_task
91+
92+
def capturing_create_task(coro):
93+
task = original_create_task(coro)
94+
created_tasks.append(task)
95+
return task
96+
97+
original_run_until_complete = test_loop.run_until_complete
98+
call_count = {"value": 0}
99+
100+
def interrupt_once(future):
101+
call_count["value"] += 1
102+
if call_count["value"] == 1:
103+
raise KeyboardInterrupt()
104+
return original_run_until_complete(future)
105+
106+
monkeypatch.setattr(test_loop, "create_task", capturing_create_task)
107+
monkeypatch.setattr(test_loop, "run_until_complete", interrupt_once)
108+
109+
try:
110+
with pytest.raises(KeyboardInterrupt):
111+
runner.run_sync(Agent(name="test-agent"), "input")
112+
113+
assert created_tasks, "Expected run_sync to schedule a task."
114+
assert created_tasks[0].done()
115+
assert created_tasks[0].cancelled()
116+
assert call_count["value"] >= 2
117+
finally:
118+
monkeypatch.undo()
119+
fresh_event_loop_policy.set_event_loop(None)
120+
test_loop.close()

0 commit comments

Comments
 (0)