Skip to content

Commit 5ce726a

Browse files
committed
fix: improving parity with openai-agent-js hitl functionality
1 parent a97b01b commit 5ce726a

File tree

12 files changed

+1127
-53
lines changed

12 files changed

+1127
-53
lines changed

examples/agent_patterns/human_in_the_loop.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,16 +113,16 @@ async def main():
113113

114114
print("\nTool call details:")
115115
print(f" Agent: {interruption.agent.name}")
116-
print(f" Tool: {interruption.raw_item.name}")
117-
print(f" Arguments: {interruption.raw_item.arguments}")
116+
print(f" Tool: {interruption.name}")
117+
print(f" Arguments: {interruption.arguments}")
118118

119119
confirmed = await confirm("\nDo you approve this tool call?")
120120

121121
if confirmed:
122-
print(f"✓ Approved: {interruption.raw_item.name}")
122+
print(f"✓ Approved: {interruption.name}")
123123
state.approve(interruption)
124124
else:
125-
print(f"✗ Rejected: {interruption.raw_item.name}")
125+
print(f"✗ Rejected: {interruption.name}")
126126
state.reject(interruption)
127127

128128
# Resume execution with the updated state

examples/agent_patterns/human_in_the_loop_stream.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,16 @@ async def main():
9494

9595
print("\nTool call details:")
9696
print(f" Agent: {interruption.agent.name}")
97-
print(f" Tool: {interruption.raw_item.name}")
98-
print(f" Arguments: {interruption.raw_item.arguments}")
97+
print(f" Tool: {interruption.name}")
98+
print(f" Arguments: {interruption.arguments}")
9999

100100
confirmed = await confirm("\nDo you approve this tool call?")
101101

102102
if confirmed:
103-
print(f"✓ Approved: {interruption.raw_item.name}")
103+
print(f"✓ Approved: {interruption.name}")
104104
state.approve(interruption)
105105
else:
106-
print(f"✗ Rejected: {interruption.raw_item.name}")
106+
print(f"✗ Rejected: {interruption.name}")
107107
state.reject(interruption)
108108

109109
# Resume execution with streaming

src/agents/_run_impl.py

Lines changed: 152 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -356,37 +356,51 @@ async def execute_tools_and_side_effects(
356356
config=run_config,
357357
),
358358
)
359-
# Check for tool approval interruptions before adding items
359+
# Add all tool results to new_step_items first, including approval items.
360+
# This ensures ToolCallItem items from processed_response.new_items are preserved
361+
# in the conversation history when resuming after an interruption.
360362
from .items import ToolApprovalItem
361363

364+
# Add all function results (including approval items) to new_step_items
365+
for result in function_results:
366+
new_step_items.append(result.run_item)
367+
368+
# Add all other tool results
369+
new_step_items.extend(computer_results)
370+
for shell_result in shell_results:
371+
new_step_items.append(shell_result)
372+
for apply_patch_result in apply_patch_results:
373+
new_step_items.append(apply_patch_result)
374+
new_step_items.extend(local_shell_results)
375+
376+
# Check for interruptions after adding all items
362377
interruptions: list[RunItem] = []
363-
approved_function_results = []
364378
for result in function_results:
365379
if isinstance(result.run_item, ToolApprovalItem):
366380
interruptions.append(result.run_item)
367-
else:
368-
approved_function_results.append(result)
381+
for shell_result in shell_results:
382+
if isinstance(shell_result, ToolApprovalItem):
383+
interruptions.append(shell_result)
384+
for apply_patch_result in apply_patch_results:
385+
if isinstance(apply_patch_result, ToolApprovalItem):
386+
interruptions.append(apply_patch_result)
369387

370388
# If there are interruptions, return immediately and skip the remaining steps below (e.g. MCP approval requests)
371389
if interruptions:
372-
# Return the interruption step
390+
# new_step_items already contains:
391+
# 1. processed_response.new_items (added earlier in this function, before the tool results) - includes ToolCallItem items
392+
# 2. All tool results including approval items (added above)
393+
# This ensures ToolCallItem items are preserved in conversation history when resuming
373394
return SingleStepResult(
374395
original_input=original_input,
375396
model_response=new_response,
376397
pre_step_items=pre_step_items,
377-
new_step_items=interruptions,
398+
new_step_items=new_step_items,
378399
next_step=NextStepInterruption(interruptions=interruptions),
379400
tool_input_guardrail_results=tool_input_guardrail_results,
380401
tool_output_guardrail_results=tool_output_guardrail_results,
381402
processed_response=processed_response,
382403
)
383-
384-
new_step_items.extend([result.run_item for result in approved_function_results])
385-
new_step_items.extend(computer_results)
386-
new_step_items.extend(shell_results)
387-
new_step_items.extend(apply_patch_results)
388-
new_step_items.extend(local_shell_results)
389-
390404
# Next, run the MCP approval requests
391405
if processed_response.mcp_approval_requests:
392406
approval_results = await cls.execute_mcp_approval_requests(
@@ -999,7 +1013,9 @@ async def run_single_tool(
9991013
# Not yet decided - need to interrupt for approval
10001014
from .items import ToolApprovalItem
10011015

1002-
approval_item = ToolApprovalItem(agent=agent, raw_item=tool_call)
1016+
approval_item = ToolApprovalItem(
1017+
agent=agent, raw_item=tool_call, tool_name=func_tool.name
1018+
)
10031019
return FunctionToolResult(
10041020
tool=func_tool, output=None, run_item=approval_item
10051021
)
@@ -1800,16 +1816,75 @@ async def execute(
18001816
context_wrapper: RunContextWrapper[TContext],
18011817
config: RunConfig,
18021818
) -> RunItem:
1819+
shell_call = _coerce_shell_call(call.tool_call)
1820+
shell_tool = call.shell_tool
1821+
1822+
# Check if approval is needed
1823+
needs_approval_result: bool = False
1824+
if isinstance(shell_tool.needs_approval, bool):
1825+
needs_approval_result = shell_tool.needs_approval
1826+
elif callable(shell_tool.needs_approval):
1827+
maybe_awaitable = shell_tool.needs_approval(
1828+
context_wrapper, shell_call.action, shell_call.call_id
1829+
)
1830+
needs_approval_result = (
1831+
await maybe_awaitable if inspect.isawaitable(maybe_awaitable) else maybe_awaitable
1832+
)
1833+
1834+
if needs_approval_result:
1835+
# Create approval item with explicit tool name
1836+
approval_item = ToolApprovalItem(
1837+
agent=agent, raw_item=call.tool_call, tool_name=shell_tool.name
1838+
)
1839+
1840+
# Handle on_approval callback if provided
1841+
if shell_tool.on_approval:
1842+
maybe_awaitable_decision = shell_tool.on_approval(context_wrapper, approval_item)
1843+
decision = (
1844+
await maybe_awaitable_decision
1845+
if inspect.isawaitable(maybe_awaitable_decision)
1846+
else maybe_awaitable_decision
1847+
)
1848+
if decision.get("approve") is True:
1849+
context_wrapper.approve_tool(approval_item)
1850+
elif decision.get("approve") is False:
1851+
context_wrapper.reject_tool(approval_item)
1852+
1853+
# Check approval status
1854+
approval_status = context_wrapper.is_tool_approved(shell_tool.name, shell_call.call_id)
1855+
1856+
if approval_status is False:
1857+
# Rejected - return rejection output
1858+
response = "Tool execution was not approved."
1859+
rejection_output: dict[str, Any] = {
1860+
"stdout": "",
1861+
"stderr": response,
1862+
"outcome": {"type": "exit", "exitCode": None},
1863+
}
1864+
rejection_raw_item: dict[str, Any] = {
1865+
"type": "shell_call_output",
1866+
"call_id": shell_call.call_id,
1867+
"output": [rejection_output],
1868+
}
1869+
return ToolCallOutputItem(
1870+
agent=agent,
1871+
output=response,
1872+
raw_item=cast(Any, rejection_raw_item),
1873+
)
1874+
1875+
if approval_status is not True:
1876+
# Pending approval - return approval item
1877+
return approval_item
1878+
1879+
# Approved or no approval needed - proceed with execution
18031880
await asyncio.gather(
1804-
hooks.on_tool_start(context_wrapper, agent, call.shell_tool),
1881+
hooks.on_tool_start(context_wrapper, agent, shell_tool),
18051882
(
1806-
agent.hooks.on_tool_start(context_wrapper, agent, call.shell_tool)
1883+
agent.hooks.on_tool_start(context_wrapper, agent, shell_tool)
18071884
if agent.hooks
18081885
else _coro.noop_coroutine()
18091886
),
18101887
)
1811-
1812-
shell_call = _coerce_shell_call(call.tool_call)
18131888
request = ShellCommandRequest(ctx_wrapper=context_wrapper, data=shell_call)
18141889
status: Literal["completed", "failed"] = "completed"
18151890
output_text = ""
@@ -1924,6 +1999,65 @@ async def execute(
19241999
config: RunConfig,
19252000
) -> RunItem:
19262001
apply_patch_tool = call.apply_patch_tool
2002+
operation = _coerce_apply_patch_operation(call.tool_call)
2003+
2004+
# Extract call_id from tool_call
2005+
call_id = _extract_apply_patch_call_id(call.tool_call)
2006+
2007+
# Check if approval is needed
2008+
needs_approval_result: bool = False
2009+
if isinstance(apply_patch_tool.needs_approval, bool):
2010+
needs_approval_result = apply_patch_tool.needs_approval
2011+
elif callable(apply_patch_tool.needs_approval):
2012+
maybe_awaitable = apply_patch_tool.needs_approval(context_wrapper, operation, call_id)
2013+
needs_approval_result = (
2014+
await maybe_awaitable if inspect.isawaitable(maybe_awaitable) else maybe_awaitable
2015+
)
2016+
2017+
if needs_approval_result:
2018+
# Create approval item with explicit tool name
2019+
approval_item = ToolApprovalItem(
2020+
agent=agent, raw_item=call.tool_call, tool_name=apply_patch_tool.name
2021+
)
2022+
2023+
# Handle on_approval callback if provided
2024+
if apply_patch_tool.on_approval:
2025+
maybe_awaitable_decision = apply_patch_tool.on_approval(
2026+
context_wrapper, approval_item
2027+
)
2028+
decision = (
2029+
await maybe_awaitable_decision
2030+
if inspect.isawaitable(maybe_awaitable_decision)
2031+
else maybe_awaitable_decision
2032+
)
2033+
if decision.get("approve") is True:
2034+
context_wrapper.approve_tool(approval_item)
2035+
elif decision.get("approve") is False:
2036+
context_wrapper.reject_tool(approval_item)
2037+
2038+
# Check approval status
2039+
approval_status = context_wrapper.is_tool_approved(apply_patch_tool.name, call_id)
2040+
2041+
if approval_status is False:
2042+
# Rejected - return rejection output
2043+
response = "Tool execution was not approved."
2044+
rejection_raw_item: dict[str, Any] = {
2045+
"type": "apply_patch_call_output",
2046+
"call_id": call_id,
2047+
"status": "failed",
2048+
"output": response,
2049+
}
2050+
return ToolCallOutputItem(
2051+
agent=agent,
2052+
output=response,
2053+
raw_item=cast(Any, rejection_raw_item),
2054+
)
2055+
2056+
if approval_status is not True:
2057+
# Pending approval - return approval item
2058+
return approval_item
2059+
2060+
# Approved or no approval needed - proceed with execution
19272061
await asyncio.gather(
19282062
hooks.on_tool_start(context_wrapper, agent, apply_patch_tool),
19292063
(
@@ -1937,7 +2071,6 @@ async def execute(
19372071
output_text = ""
19382072

19392073
try:
1940-
operation = _coerce_apply_patch_operation(call.tool_call)
19412074
editor = apply_patch_tool.editor
19422075
if operation.type == "create_file":
19432076
result = editor.create_file(operation)

src/agents/items.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,20 +327,71 @@ class MCPApprovalResponseItem(RunItemBase[McpApprovalResponse]):
327327
type: Literal["mcp_approval_response_item"] = "mcp_approval_response_item"
328328

329329

330+
# Union type for tool approval raw items - supports function tools, hosted tools, shell tools, etc.
331+
ToolApprovalRawItem: TypeAlias = Union[
332+
ResponseFunctionToolCall,
333+
McpCall,
334+
LocalShellCall,
335+
dict[str, Any], # For flexibility with other tool types
336+
]
337+
338+
330339
@dataclass
331-
class ToolApprovalItem(RunItemBase[ResponseFunctionToolCall]):
340+
class ToolApprovalItem(RunItemBase[Any]):
332341
"""Represents a tool call that requires approval before execution.
333342
334343
When a tool has `needs_approval=True`, the run will be interrupted and this item will be
335344
added to the interruptions list. You can then approve or reject the tool call using
336345
RunState.approve() or RunState.reject() and resume the run.
337346
"""
338347

339-
raw_item: ResponseFunctionToolCall
340-
"""The raw function tool call that requires approval."""
348+
raw_item: ToolApprovalRawItem
349+
"""The raw tool call that requires approval. Can be a function tool call, hosted tool call,
350+
shell call, or other tool type.
351+
"""
352+
353+
tool_name: str | None = None
354+
"""Explicit tool name to use for approval tracking when not present on the raw item.
355+
If not provided, falls back to raw_item.name (or raw_item.get("name") when the raw item is a dict).
356+
"""
341357

342358
type: Literal["tool_approval_item"] = "tool_approval_item"
343359

360+
def __post_init__(self) -> None:
361+
"""Set tool_name from raw_item.name if not explicitly provided."""
362+
if self.tool_name is None:
363+
# Extract name from raw_item - handle different types
364+
if isinstance(self.raw_item, dict):
365+
self.tool_name = self.raw_item.get("name")
366+
elif hasattr(self.raw_item, "name"):
367+
self.tool_name = self.raw_item.name
368+
else:
369+
self.tool_name = None
370+
371+
@property
372+
def name(self) -> str | None:
373+
"""Returns the tool name if available on the raw item or provided explicitly.
374+
375+
Kept for backwards compatibility with code that previously relied on raw_item.name.
376+
"""
377+
return self.tool_name or (
378+
getattr(self.raw_item, "name", None)
379+
if not isinstance(self.raw_item, dict)
380+
else self.raw_item.get("name")
381+
)
382+
383+
@property
384+
def arguments(self) -> str | None:
385+
"""Returns the arguments if the raw item has an arguments property, otherwise None.
386+
387+
This provides a safe way to access tool call arguments regardless of the raw_item type.
388+
"""
389+
if isinstance(self.raw_item, dict):
390+
return self.raw_item.get("arguments")
391+
elif hasattr(self.raw_item, "arguments"):
392+
return self.raw_item.arguments
393+
return None
394+
344395

345396
RunItem: TypeAlias = Union[
346397
MessageOutputItem,

0 commit comments

Comments
 (0)