Skip to content

Commit 5f0c056

Browse files
committed
Fix HITL resume for computer actions and avoid duplicate rejections
1 parent 7554dba commit 5f0c056

File tree

2 files changed

+276
-3
lines changed

2 files changed

+276
-3
lines changed

src/agents/run_internal/run_loop.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -711,11 +711,11 @@ async def _collect_runs_by_approval(
711711
existing_pending=existing_pending,
712712
)
713713

714-
if approval_status is False:
715-
rejection_items.append(rejection_builder(call_id))
714+
if output_exists_checker and output_exists_checker(call_id):
716715
continue
717716

718-
if output_exists_checker and output_exists_checker(call_id):
717+
if approval_status is False:
718+
rejection_items.append(rejection_builder(call_id))
719719
continue
720720

721721
needs_approval = True
@@ -747,6 +747,12 @@ def _shell_call_id_from_run(run: ToolRunShellCall) -> str:
747747
def _apply_patch_call_id_from_run(run: ToolRunApplyPatchCall) -> str:
748748
return extract_apply_patch_call_id(run.tool_call)
749749

750+
def _computer_call_id_from_run(run: ToolRunComputerAction) -> str:
751+
call_id = extract_tool_call_id(run.tool_call)
752+
if not call_id:
753+
raise ModelBehaviorError("Computer action is missing call_id.")
754+
return call_id
755+
750756
def _shell_tool_name(run: ToolRunShellCall) -> str:
751757
return run.shell_tool.name
752758

@@ -784,6 +790,9 @@ def _shell_output_exists(call_id: str) -> bool:
784790
def _apply_patch_output_exists(call_id: str) -> bool:
785791
return _has_output_item(call_id, "apply_patch_call_output")
786792

793+
def _computer_output_exists(call_id: str) -> bool:
794+
return _has_output_item(call_id, "computer_call_output")
795+
787796
def _add_pending_interruption(item: ToolApprovalItem | None) -> None:
788797
if item is None:
789798
return
@@ -926,6 +935,23 @@ async def _rebuild_function_runs_from_approvals() -> list[ToolRunFunction]:
926935
for interruption in result.interruptions:
927936
_add_pending_interruption(interruption)
928937

938+
pending_computer_actions: list[ToolRunComputerAction] = []
939+
for action in processed_response.computer_actions:
940+
call_id = _computer_call_id_from_run(action)
941+
if _computer_output_exists(call_id):
942+
continue
943+
pending_computer_actions.append(action)
944+
945+
computer_results: list[RunItem] = []
946+
if pending_computer_actions:
947+
computer_results = await execute_computer_actions(
948+
agent=agent,
949+
actions=pending_computer_actions,
950+
hooks=hooks,
951+
context_wrapper=context_wrapper,
952+
config=run_config,
953+
)
954+
929955
# Execute shell/apply_patch only when approved; emit rejections otherwise.
930956
approved_shell_calls, rejected_shell_results = await _collect_runs_by_approval(
931957
processed_response.shell_calls,
@@ -975,6 +1001,8 @@ def append_if_new(item: RunItem) -> None:
9751001

9761002
for function_result in function_results:
9771003
append_if_new(function_result.run_item)
1004+
for computer_result in computer_results:
1005+
append_if_new(computer_result)
9781006
for rejection_item in rejected_function_outputs:
9791007
append_if_new(rejection_item)
9801008
for pending_item in pending_interruptions:

tests/test_hitl_error_scenarios.py

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from typing import Any, Callable, Optional, cast
66

77
import pytest
8+
from openai.types.responses import ResponseComputerToolCall
9+
from openai.types.responses.response_computer_tool_call import ActionScreenshot
810
from openai.types.responses.response_input_param import (
911
ComputerCallOutput,
1012
LocalShellCallOutput,
@@ -14,6 +16,7 @@
1416
from agents import (
1517
Agent,
1618
ApplyPatchTool,
19+
ComputerTool,
1720
LocalShellTool,
1821
Runner,
1922
RunResult,
@@ -22,6 +25,7 @@
2225
ToolApprovalItem,
2326
function_tool,
2427
)
28+
from agents.computer import Computer, Environment
2529
from agents.exceptions import ModelBehaviorError, UserError
2630
from agents.items import (
2731
MCPApprovalResponseItem,
@@ -38,6 +42,7 @@
3842
NextStepInterruption,
3943
NextStepRunAgain,
4044
ProcessedResponse,
45+
ToolRunComputerAction,
4146
ToolRunFunction,
4247
ToolRunMCPApprovalRequest,
4348
ToolRunShellCall,
@@ -76,6 +81,49 @@
7681
)
7782

7883

84+
class TrackingComputer(Computer):
85+
"""Minimal computer implementation that records method calls."""
86+
87+
def __init__(self) -> None:
88+
self.calls: list[str] = []
89+
90+
@property
91+
def environment(self) -> Environment:
92+
return "mac"
93+
94+
@property
95+
def dimensions(self) -> tuple[int, int]:
96+
return (1, 1)
97+
98+
def screenshot(self) -> str:
99+
self.calls.append("screenshot")
100+
return "img"
101+
102+
def click(self, _x: int, _y: int, _button: str) -> None:
103+
self.calls.append("click")
104+
105+
def double_click(self, _x: int, _y: int) -> None:
106+
self.calls.append("double_click")
107+
108+
def scroll(self, _x: int, _y: int, _scroll_x: int, _scroll_y: int) -> None:
109+
self.calls.append("scroll")
110+
111+
def type(self, _text: str) -> None:
112+
self.calls.append("type")
113+
114+
def wait(self) -> None:
115+
self.calls.append("wait")
116+
117+
def move(self, _x: int, _y: int) -> None:
118+
self.calls.append("move")
119+
120+
def keypress(self, _keys: list[str]) -> None:
121+
self.calls.append("keypress")
122+
123+
def drag(self, _path: list[tuple[int, int]]) -> None:
124+
self.calls.append("drag")
125+
126+
79127
def _shell_approval_setup() -> ApprovalScenario:
80128
tool = ShellTool(executor=lambda request: "shell_output", needs_approval=require_approval)
81129
shell_call = make_shell_call("call_shell_1", id_value="shell_1", commands=["echo test"])
@@ -889,6 +937,123 @@ async def test_resume_skips_shell_calls_with_existing_output() -> None:
889937
assert not result.new_step_items, "Shell call should not run when output already exists"
890938

891939

940+
@pytest.mark.asyncio
941+
async def test_resume_executes_pending_computer_actions() -> None:
942+
"""Pending computer actions should execute when resuming an interrupted turn."""
943+
944+
computer = TrackingComputer()
945+
computer_tool = ComputerTool(computer=computer)
946+
model, agent = make_model_and_agent(tools=[computer_tool])
947+
948+
computer_call = ResponseComputerToolCall(
949+
type="computer_call",
950+
id="comp_pending",
951+
call_id="comp_pending",
952+
status="in_progress",
953+
action=ActionScreenshot(type="screenshot"),
954+
pending_safety_checks=[],
955+
)
956+
957+
processed_response = ProcessedResponse(
958+
new_items=[],
959+
handoffs=[],
960+
functions=[],
961+
computer_actions=[
962+
ToolRunComputerAction(tool_call=computer_call, computer_tool=computer_tool)
963+
],
964+
local_shell_calls=[],
965+
shell_calls=[],
966+
apply_patch_calls=[],
967+
tools_used=[computer_tool.name],
968+
mcp_approval_requests=[],
969+
interruptions=[],
970+
)
971+
972+
result = await run_loop.resolve_interrupted_turn(
973+
agent=agent,
974+
original_input="resume computer",
975+
original_pre_step_items=[],
976+
new_response=ModelResponse(output=[], usage=Usage(), response_id="resp"),
977+
processed_response=processed_response,
978+
hooks=RunHooks(),
979+
context_wrapper=make_context_wrapper(),
980+
run_config=RunConfig(),
981+
run_state=None,
982+
)
983+
984+
outputs = [
985+
item
986+
for item in result.new_step_items
987+
if isinstance(item, ToolCallOutputItem)
988+
and isinstance(item.raw_item, dict)
989+
and item.raw_item.get("type") == "computer_call_output"
990+
]
991+
assert outputs, "Computer action should run when resuming without prior output"
992+
assert computer.calls, "Computer should have been invoked"
993+
assert isinstance(result.next_step, NextStepRunAgain)
994+
995+
996+
@pytest.mark.asyncio
997+
async def test_resume_skips_computer_actions_with_existing_output() -> None:
998+
"""Computer actions with persisted output should not execute again when resuming."""
999+
1000+
computer = TrackingComputer()
1001+
computer_tool = ComputerTool(computer=computer)
1002+
model, agent = make_model_and_agent(tools=[computer_tool])
1003+
1004+
computer_call = ResponseComputerToolCall(
1005+
type="computer_call",
1006+
id="comp_skip",
1007+
call_id="comp_skip",
1008+
status="completed",
1009+
action=ActionScreenshot(type="screenshot"),
1010+
pending_safety_checks=[],
1011+
)
1012+
1013+
processed_response = ProcessedResponse(
1014+
new_items=[],
1015+
handoffs=[],
1016+
functions=[],
1017+
computer_actions=[
1018+
ToolRunComputerAction(tool_call=computer_call, computer_tool=computer_tool)
1019+
],
1020+
local_shell_calls=[],
1021+
shell_calls=[],
1022+
apply_patch_calls=[],
1023+
tools_used=[computer_tool.name],
1024+
mcp_approval_requests=[],
1025+
interruptions=[],
1026+
)
1027+
1028+
original_pre_step_items = [
1029+
ToolCallOutputItem(
1030+
agent=agent,
1031+
raw_item={
1032+
"type": "computer_call_output",
1033+
"call_id": "comp_skip",
1034+
"output": {"type": "computer_screenshot", "image_url": ""},
1035+
},
1036+
output="image_url",
1037+
)
1038+
]
1039+
1040+
result = await run_loop.resolve_interrupted_turn(
1041+
agent=agent,
1042+
original_input="resume computer existing",
1043+
original_pre_step_items=cast(list[RunItem], original_pre_step_items),
1044+
new_response=ModelResponse(output=[], usage=Usage(), response_id="resp"),
1045+
processed_response=processed_response,
1046+
hooks=RunHooks(),
1047+
context_wrapper=make_context_wrapper(),
1048+
run_config=RunConfig(),
1049+
run_state=None,
1050+
)
1051+
1052+
assert not computer.calls, "Computer action should not run when output already exists"
1053+
assert not result.new_step_items, "No new items should be emitted when output exists"
1054+
assert isinstance(result.next_step, NextStepRunAgain)
1055+
1056+
8921057
@pytest.mark.asyncio
8931058
async def test_rebuild_function_runs_handles_pending_and_rejections() -> None:
8941059
"""Rebuilt function runs should surface pending approvals and emit rejections."""
@@ -1017,6 +1182,86 @@ async def test_rejected_shell_calls_emit_rejection_output() -> None:
10171182
assert isinstance(result.next_step, NextStepRunAgain)
10181183

10191184

1185+
@pytest.mark.asyncio
1186+
async def test_rejected_shell_calls_with_existing_output_are_not_duplicated() -> None:
1187+
"""Rejected shell calls with persisted output should not emit duplicate rejections."""
1188+
1189+
shell_tool = ShellTool(executor=lambda _req: "should_not_run", needs_approval=True)
1190+
_model, agent = make_model_and_agent(tools=[shell_tool])
1191+
context_wrapper = make_context_wrapper()
1192+
1193+
shell_call = make_shell_call(
1194+
"call_reject_shell_dup",
1195+
id_value="shell_reject_dup",
1196+
commands=["echo test"],
1197+
status="in_progress",
1198+
)
1199+
approval_item = ToolApprovalItem(
1200+
agent=agent,
1201+
raw_item=cast(dict[str, Any], shell_call),
1202+
tool_name=shell_tool.name,
1203+
)
1204+
context_wrapper.reject_tool(approval_item)
1205+
1206+
processed_response = ProcessedResponse(
1207+
new_items=[],
1208+
handoffs=[],
1209+
functions=[],
1210+
computer_actions=[],
1211+
local_shell_calls=[],
1212+
shell_calls=[ToolRunShellCall(tool_call=shell_call, shell_tool=shell_tool)],
1213+
apply_patch_calls=[],
1214+
tools_used=[],
1215+
mcp_approval_requests=[],
1216+
interruptions=[],
1217+
)
1218+
1219+
original_pre_step_items = [
1220+
ToolCallOutputItem(
1221+
agent=agent,
1222+
raw_item=cast(
1223+
dict[str, Any],
1224+
{
1225+
"type": "shell_call_output",
1226+
"call_id": "call_reject_shell_dup",
1227+
"output": [
1228+
{
1229+
"stdout": "",
1230+
"stderr": HITL_REJECTION_MSG,
1231+
"outcome": {"type": "exit", "exit_code": 1},
1232+
}
1233+
],
1234+
},
1235+
),
1236+
output=HITL_REJECTION_MSG,
1237+
)
1238+
]
1239+
1240+
result = await run_loop.resolve_interrupted_turn(
1241+
agent=agent,
1242+
original_input="resume shell rejection existing",
1243+
original_pre_step_items=cast(list[RunItem], original_pre_step_items),
1244+
new_response=ModelResponse(output=[], usage=Usage(), response_id="resp"),
1245+
processed_response=processed_response,
1246+
hooks=RunHooks(),
1247+
context_wrapper=context_wrapper,
1248+
run_config=RunConfig(),
1249+
run_state=None,
1250+
)
1251+
1252+
duplicate_rejections = [
1253+
item
1254+
for item in result.new_step_items
1255+
if isinstance(item, ToolCallOutputItem)
1256+
and isinstance(item.raw_item, dict)
1257+
and item.raw_item.get("type") == "shell_call_output"
1258+
and HITL_REJECTION_MSG in str(item.output)
1259+
]
1260+
1261+
assert not duplicate_rejections, "No duplicate rejection outputs should be emitted"
1262+
assert isinstance(result.next_step, NextStepRunAgain)
1263+
1264+
10201265
@pytest.mark.asyncio
10211266
async def test_mcp_callback_approvals_are_processed() -> None:
10221267
"""MCP approval requests with callbacks should emit approval responses."""

0 commit comments

Comments
 (0)