Skip to content

Commit a1b06b9

Browse files
committed
fix: rebuild HITL function runs from object approvals
1 parent aaa0e4c commit a1b06b9

File tree

2 files changed, with 168 additions and 39 deletions.

src/agents/run_internal/run_loop.py

Lines changed: 51 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -837,44 +837,57 @@ async def _rebuild_function_runs_from_approvals() -> list[ToolRunFunction]:
837837
if not isinstance(approval, ToolApprovalItem):
838838
continue
839839
raw = approval.raw_item
840-
if isinstance(raw, dict) and raw.get("type") == "function_call":
841-
name = raw.get("name")
842-
if name and isinstance(name, str) and name in tool_map:
843-
rebuilt_call_id = extract_tool_call_id(raw)
844-
arguments = raw.get("arguments", "{}")
845-
status = raw.get("status")
846-
if isinstance(rebuilt_call_id, str) and isinstance(arguments, str):
847-
# Validate status is a valid Literal type
848-
valid_status: Literal["in_progress", "completed", "incomplete"] | None = (
849-
None
850-
)
851-
if isinstance(status, str) and status in (
852-
"in_progress",
853-
"completed",
854-
"incomplete",
855-
):
856-
valid_status = status # type: ignore[assignment]
857-
tool_call = ResponseFunctionToolCall(
858-
type="function_call",
859-
name=name,
860-
call_id=rebuilt_call_id,
861-
arguments=arguments,
862-
status=valid_status,
863-
)
864-
approval_status = context_wrapper.get_approval_status(
865-
name, rebuilt_call_id, existing_pending=approval
866-
)
867-
if approval_status is False:
868-
_record_function_rejection(rebuilt_call_id, tool_call)
869-
continue
870-
if approval_status is None:
871-
if rebuilt_call_id not in existing_pending_call_ids:
872-
_add_pending_interruption(approval)
873-
existing_pending_call_ids.add(rebuilt_call_id)
874-
continue
875-
rebuilt_runs.append(
876-
ToolRunFunction(function_tool=tool_map[name], tool_call=tool_call)
877-
)
840+
raw_type = get_mapping_or_attr(raw, "type")
841+
if raw_type != "function_call":
842+
continue
843+
name = get_mapping_or_attr(raw, "name")
844+
if not (isinstance(name, str) and name in tool_map):
845+
continue
846+
847+
rebuilt_call_id: str | None
848+
arguments: str | None
849+
tool_call: ResponseFunctionToolCall
850+
if isinstance(raw, ResponseFunctionToolCall):
851+
rebuilt_call_id = raw.call_id
852+
arguments = raw.arguments
853+
tool_call = raw
854+
else:
855+
rebuilt_call_id = extract_tool_call_id(raw)
856+
arguments = get_mapping_or_attr(raw, "arguments") or "{}"
857+
status = get_mapping_or_attr(raw, "status")
858+
if not (isinstance(rebuilt_call_id, str) and isinstance(arguments, str)):
859+
continue
860+
# Validate status is a valid Literal type
861+
valid_status: Literal["in_progress", "completed", "incomplete"] | None = None
862+
if isinstance(status, str) and status in (
863+
"in_progress",
864+
"completed",
865+
"incomplete",
866+
):
867+
valid_status = status # type: ignore[assignment]
868+
tool_call = ResponseFunctionToolCall(
869+
type="function_call",
870+
name=name,
871+
call_id=rebuilt_call_id,
872+
arguments=arguments,
873+
status=valid_status,
874+
)
875+
876+
if not (isinstance(rebuilt_call_id, str) and isinstance(arguments, str)):
877+
continue
878+
879+
approval_status = context_wrapper.get_approval_status(
880+
name, rebuilt_call_id, existing_pending=approval
881+
)
882+
if approval_status is False:
883+
_record_function_rejection(rebuilt_call_id, tool_call)
884+
continue
885+
if approval_status is None:
886+
if rebuilt_call_id not in existing_pending_call_ids:
887+
_add_pending_interruption(approval)
888+
existing_pending_call_ids.add(rebuilt_call_id)
889+
continue
890+
rebuilt_runs.append(ToolRunFunction(function_tool=tool_map[name], tool_call=tool_call))
878891
return rebuilt_runs
879892

880893
# Run only the approved function calls for this turn; emit rejections for denied ones.

tests/test_hitl_error_scenarios.py

Lines changed: 117 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
from typing import Any, Callable, Optional, cast
66

77
import pytest
8-
from openai.types.responses import ResponseComputerToolCall
8+
from openai.types.responses import ResponseComputerToolCall, ResponseFunctionToolCall
99
from openai.types.responses.response_computer_tool_call import ActionScreenshot
1010
from openai.types.responses.response_input_param import (
1111
ComputerCallOutput,
@@ -842,6 +842,122 @@ def approve_me(reason: Optional[str] = None) -> str: # noqa: UP007
842842
assert "call-rebuild-1" in executed_call_ids, "Function should be rebuilt and executed"
843843

844844

845+
@pytest.mark.asyncio
846+
async def test_resume_rebuilds_function_runs_from_object_approvals() -> None:
847+
"""Rebuild should handle ResponseFunctionToolCall approval items."""
848+
849+
@function_tool(needs_approval=True)
850+
def approve_me(reason: Optional[str] = None) -> str: # noqa: UP007
851+
return f"approved:{reason}" if reason else "approved"
852+
853+
model, agent = make_model_and_agent(tools=[approve_me])
854+
tool_call = make_function_tool_call(
855+
approve_me.name,
856+
call_id="call-rebuild-obj",
857+
arguments='{"reason": "ok"}',
858+
)
859+
assert isinstance(tool_call, ResponseFunctionToolCall)
860+
approval_item = ToolApprovalItem(agent=agent, raw_item=tool_call)
861+
context_wrapper = make_context_wrapper()
862+
context_wrapper.approve_tool(approval_item)
863+
864+
run_state = make_state_with_interruptions(agent, [approval_item])
865+
processed_response = ProcessedResponse(
866+
new_items=[],
867+
handoffs=[],
868+
functions=[],
869+
computer_actions=[],
870+
local_shell_calls=[],
871+
shell_calls=[],
872+
apply_patch_calls=[],
873+
tools_used=[],
874+
mcp_approval_requests=[],
875+
interruptions=[],
876+
)
877+
878+
result = await run_loop.resolve_interrupted_turn(
879+
agent=agent,
880+
original_input="resume approvals",
881+
original_pre_step_items=[],
882+
new_response=ModelResponse(output=[], usage=Usage(), response_id="resp"),
883+
processed_response=processed_response,
884+
hooks=RunHooks(),
885+
context_wrapper=context_wrapper,
886+
run_config=RunConfig(),
887+
run_state=run_state,
888+
)
889+
890+
assert not isinstance(result.next_step, NextStepInterruption)
891+
executed_call_ids = {
892+
extract_tool_call_id(item.raw_item)
893+
for item in result.new_step_items
894+
if isinstance(item, ToolCallOutputItem)
895+
}
896+
assert "call-rebuild-obj" in executed_call_ids, (
897+
"Function should be rebuilt from ResponseFunctionToolCall approval"
898+
)
899+
900+
901+
@pytest.mark.asyncio
902+
async def test_rebuild_function_runs_handles_object_pending_and_rejections() -> None:
903+
"""Rebuild should surface pending approvals and emit rejections for object approvals."""
904+
905+
@function_tool(needs_approval=True)
906+
def reject_me(text: str = "nope") -> str:
907+
return text
908+
909+
@function_tool(needs_approval=True)
910+
def pending_me(text: str = "wait") -> str:
911+
return text
912+
913+
_model, agent = make_model_and_agent(tools=[reject_me, pending_me])
914+
context_wrapper = make_context_wrapper()
915+
916+
rejected_call = make_function_tool_call(reject_me.name, call_id="obj-reject")
917+
pending_call = make_function_tool_call(pending_me.name, call_id="obj-pending")
918+
assert isinstance(rejected_call, ResponseFunctionToolCall)
919+
assert isinstance(pending_call, ResponseFunctionToolCall)
920+
921+
rejected_item = ToolApprovalItem(agent=agent, raw_item=rejected_call)
922+
pending_item = ToolApprovalItem(agent=agent, raw_item=pending_call)
923+
context_wrapper.reject_tool(rejected_item)
924+
925+
run_state = make_state_with_interruptions(agent, [rejected_item, pending_item])
926+
processed_response = ProcessedResponse(
927+
new_items=[],
928+
handoffs=[],
929+
functions=[],
930+
computer_actions=[],
931+
local_shell_calls=[],
932+
shell_calls=[],
933+
apply_patch_calls=[],
934+
tools_used=[],
935+
mcp_approval_requests=[],
936+
interruptions=[],
937+
)
938+
939+
result = await run_loop.resolve_interrupted_turn(
940+
agent=agent,
941+
original_input="resume approvals",
942+
original_pre_step_items=[],
943+
new_response=ModelResponse(output=[], usage=Usage(), response_id="resp"),
944+
processed_response=processed_response,
945+
hooks=RunHooks(),
946+
context_wrapper=context_wrapper,
947+
run_config=RunConfig(),
948+
run_state=run_state,
949+
)
950+
951+
assert isinstance(result.next_step, NextStepInterruption)
952+
assert pending_item in result.next_step.interruptions
953+
rejection_outputs = [
954+
item
955+
for item in result.new_step_items
956+
if isinstance(item, ToolCallOutputItem) and item.output == HITL_REJECTION_MSG
957+
]
958+
assert rejection_outputs, "Rejected function call should emit rejection output"
959+
960+
845961
@pytest.mark.asyncio
846962
async def test_resume_skips_non_hitl_function_calls() -> None:
847963
"""Non-HITL function calls should not re-run when resuming unrelated approvals."""

0 commit comments

Comments
 (0)