Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
ActionType,
ActionWait,
)
from openai.types.responses.response_input_item_param import (
ComputerCallOutputAcknowledgedSafetyCheck,
)
from openai.types.responses.response_input_param import ComputerCallOutput, McpApprovalResponse
from openai.types.responses.response_output_item import (
ImageGenerationCall,
Expand Down Expand Up @@ -67,6 +70,7 @@
from .stream_events import RunItemStreamEvent, StreamEvent
from .tool import (
ComputerTool,
ComputerToolSafetyCheckData,
FunctionTool,
FunctionToolResult,
HostedMCPTool,
Expand Down Expand Up @@ -638,13 +642,37 @@ async def execute_computer_actions(
results: list[RunItem] = []
# Need to run these serially, because each action can affect the computer state
for action in actions:
acknowledged: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None
if action.tool_call.pending_safety_checks and action.computer_tool.on_safety_check:
acknowledged = []
for check in action.tool_call.pending_safety_checks:
data = ComputerToolSafetyCheckData(
ctx_wrapper=context_wrapper,
agent=agent,
tool_call=action.tool_call,
safety_check=check,
)
maybe = action.computer_tool.on_safety_check(data)
ack = await maybe if inspect.isawaitable(maybe) else maybe
if ack:
acknowledged.append(
ComputerCallOutputAcknowledgedSafetyCheck(
id=check.id,
code=check.code,
message=check.message,
)
)
else:
raise UserError("Computer tool safety check was not acknowledged")

results.append(
await ComputerAction.execute(
agent=agent,
action=action,
hooks=hooks,
context_wrapper=context_wrapper,
config=config,
acknowledged_safety_checks=acknowledged,
)
)

Expand Down Expand Up @@ -998,6 +1026,7 @@ async def execute(
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
config: RunConfig,
acknowledged_safety_checks: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None,
) -> RunItem:
output_func = (
cls._get_screenshot_async(action.computer_tool.computer, action.tool_call)
Expand Down Expand Up @@ -1036,6 +1065,7 @@ async def execute(
"image_url": image_url,
},
type="computer_call_output",
acknowledged_safety_checks=acknowledged_safety_checks,
),
)

Expand Down
24 changes: 24 additions & 0 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from typing import TYPE_CHECKING, Any, Callable, Literal, Union, overload

from openai.types.responses.file_search_tool_param import Filters, RankingOptions
from openai.types.responses.response_computer_tool_call import (
PendingSafetyCheck,
ResponseComputerToolCall,
)
from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest
from openai.types.responses.tool_param import CodeInterpreter, ImageGeneration, Mcp
from openai.types.responses.web_search_tool_param import UserLocation
Expand Down Expand Up @@ -141,11 +145,31 @@ class ComputerTool:
as well as implements the computer actions like click, screenshot, etc.
"""

on_safety_check: Callable[[ComputerToolSafetyCheckData], MaybeAwaitable[bool]] | None = None
"""Optional callback to acknowledge computer tool safety checks."""

@property
def name(self):
return "computer_use_preview"


@dataclass
class ComputerToolSafetyCheckData:
    """Information about a computer tool safety check.

    One instance is built for each pending safety check on a computer tool call and
    passed to the `ComputerTool.on_safety_check` callback. The callback returns a bool
    (or an awaitable resolving to one) indicating whether the check is acknowledged.
    """

    ctx_wrapper: RunContextWrapper[Any]
    """The run context."""

    agent: Agent[Any]
    """The agent performing the computer action."""

    tool_call: ResponseComputerToolCall
    """The computer tool call."""

    safety_check: PendingSafetyCheck
    """The pending safety check to acknowledge."""


@dataclass
class MCPToolApprovalRequest:
"""A request to approve a tool call."""
Expand Down
45 changes: 44 additions & 1 deletion tests/test_computer_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ActionScroll,
ActionType,
ActionWait,
PendingSafetyCheck,
ResponseComputerToolCall,
)

Expand All @@ -31,8 +32,9 @@
RunContextWrapper,
RunHooks,
)
from agents._run_impl import ComputerAction, ToolRunComputerAction
from agents._run_impl import ComputerAction, RunImpl, ToolRunComputerAction
from agents.items import ToolCallOutputItem
from agents.tool import ComputerToolSafetyCheckData


class LoggingComputer(Computer):
Expand Down Expand Up @@ -309,3 +311,44 @@ async def test_execute_invokes_hooks_and_returns_tool_call_output() -> None:
assert raw["output"]["type"] == "computer_screenshot"
assert "image_url" in raw["output"]
assert raw["output"]["image_url"].endswith("xyz")


@pytest.mark.asyncio
async def test_pending_safety_check_acknowledged() -> None:
    """An approved pending safety check is reported back on the call output."""

    seen_checks: list[ComputerToolSafetyCheckData] = []

    def approve(data: ComputerToolSafetyCheckData) -> bool:
        # Record every invocation so the test can verify what the callback received.
        seen_checks.append(data)
        return True

    computer_tool = ComputerTool(
        computer=LoggingComputer(screenshot_return="img"),
        on_safety_check=approve,
    )
    pending_call = ResponseComputerToolCall(
        id="t1",
        type="computer_call",
        action=ActionClick(type="click", x=1, y=1, button="left"),
        call_id="t1",
        pending_safety_checks=[PendingSafetyCheck(id="sc", code="c", message="m")],
        status="completed",
    )

    outputs = await RunImpl.execute_computer_actions(
        agent=Agent(name="a", tools=[computer_tool]),
        actions=[ToolRunComputerAction(tool_call=pending_call, computer_tool=computer_tool)],
        hooks=RunHooks[Any](),
        context_wrapper=RunContextWrapper(context=None),
        config=RunConfig(),
    )

    # Exactly one action was submitted, so exactly one output item is expected.
    assert len(outputs) == 1
    raw_item = outputs[0].raw_item
    assert isinstance(raw_item, dict)
    expected_ack = [{"id": "sc", "code": "c", "message": "m"}]
    assert raw_item.get("acknowledged_safety_checks") == expected_ack

    # The callback ran once and was handed the pending check from the tool call.
    assert len(seen_checks) == 1
    assert seen_checks[0].safety_check.id == "sc"