From ba4d4cf28f80e197a5a7bef9c1873fd7d70958fe Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Fri, 14 Nov 2025 02:43:01 +0900 Subject: [PATCH] Add new tools for gpt-5.1 --- examples/tools/apply_patch.py | 169 +++++ examples/tools/code_interpreter.py | 29 +- examples/tools/image_generator.py | 36 +- examples/tools/shell.py | 114 +++ examples/tools/web_search_filters.py | 32 +- src/agents/__init__.py | 31 +- src/agents/_run_impl.py | 718 +++++++++++++++++- src/agents/apply_diff.py | 329 ++++++++ src/agents/editor.py | 45 ++ src/agents/items.py | 38 +- src/agents/models/openai_responses.py | 8 + src/agents/run.py | 16 +- src/agents/tool.py | 100 ++- .../memory/test_dapr_redis_integration.py | 17 + tests/test_agents_logging.py | 13 + tests/test_apply_diff.py | 36 + tests/test_apply_diff_helpers.py | 73 ++ tests/test_apply_patch_tool.py | 139 ++++ tests/test_computer_action.py | 5 +- tests/test_function_tool.py | 15 + tests/test_run_step_execution.py | 15 +- tests/test_shell_call_serialization.py | 63 ++ tests/test_shell_tool.py | 137 ++++ tests/test_tool_metadata.py | 72 ++ 24 files changed, 2191 insertions(+), 59 deletions(-) create mode 100644 examples/tools/apply_patch.py create mode 100644 examples/tools/shell.py create mode 100644 src/agents/apply_diff.py create mode 100644 src/agents/editor.py create mode 100644 tests/test_agents_logging.py create mode 100644 tests/test_apply_diff.py create mode 100644 tests/test_apply_diff_helpers.py create mode 100644 tests/test_apply_patch_tool.py create mode 100644 tests/test_shell_call_serialization.py create mode 100644 tests/test_shell_tool.py create mode 100644 tests/test_tool_metadata.py diff --git a/examples/tools/apply_patch.py b/examples/tools/apply_patch.py new file mode 100644 index 000000000..19d0cfb7d --- /dev/null +++ b/examples/tools/apply_patch.py @@ -0,0 +1,169 @@ +import argparse +import asyncio +import hashlib +import os +import tempfile +from pathlib import Path + +from agents import Agent, ApplyPatchTool, ModelSettings, Runner, apply_diff, trace +from agents.editor import ApplyPatchOperation, ApplyPatchResult + + +class ApprovalTracker: + def __init__(self) -> None: + self._approved: set[str] = set() + + def fingerprint(self, operation: ApplyPatchOperation, relative_path: str) -> str: + hasher = hashlib.sha256() + hasher.update(operation.type.encode("utf-8")) + hasher.update(b"\0") + hasher.update(relative_path.encode("utf-8")) + hasher.update(b"\0") + hasher.update((operation.diff or "").encode("utf-8")) + return hasher.hexdigest() + + def remember(self, fingerprint: str) -> None: + self._approved.add(fingerprint) + + def is_approved(self, fingerprint: str) -> bool: + return fingerprint in self._approved + + +class WorkspaceEditor: + def __init__(self, root: Path, approvals: ApprovalTracker, auto_approve: bool) -> None: + self._root = root.resolve() + self._approvals = approvals + self._auto_approve = auto_approve or os.environ.get("APPLY_PATCH_AUTO_APPROVE") == "1" + + def create_file(self, operation: ApplyPatchOperation) -> ApplyPatchResult: + relative = self._relative_path(operation.path) + self._require_approval(operation, relative) + target = self._resolve(operation.path, ensure_parent=True) + diff = operation.diff or "" + content = apply_diff("", diff, mode="create") + target.write_text(content, encoding="utf-8") + return ApplyPatchResult(output=f"Created {relative}") + + def update_file(self, operation: ApplyPatchOperation) -> ApplyPatchResult: + relative = self._relative_path(operation.path) + self._require_approval(operation, relative) + target = self._resolve(operation.path) + original = target.read_text(encoding="utf-8") + diff = operation.diff or "" + patched = apply_diff(original, diff) + target.write_text(patched, encoding="utf-8") + return ApplyPatchResult(output=f"Updated {relative}") + + def delete_file(self, operation: ApplyPatchOperation) -> ApplyPatchResult: + relative = self._relative_path(operation.path) + self._require_approval(operation, relative) + target = self._resolve(operation.path) + target.unlink(missing_ok=True) + return ApplyPatchResult(output=f"Deleted {relative}") + + def _relative_path(self, value: str) -> str: + resolved = self._resolve(value) + return resolved.relative_to(self._root).as_posix() + + def _resolve(self, relative: str, ensure_parent: bool = False) -> Path: + candidate = Path(relative) + target = candidate if candidate.is_absolute() else (self._root / candidate) + target = target.resolve() + try: + target.relative_to(self._root) + except ValueError: + raise RuntimeError(f"Operation outside workspace: {relative}") from None + if ensure_parent: + target.parent.mkdir(parents=True, exist_ok=True) + return target + + def _require_approval(self, operation: ApplyPatchOperation, display_path: str) -> None: + fingerprint = self._approvals.fingerprint(operation, display_path) + if self._auto_approve or self._approvals.is_approved(fingerprint): + self._approvals.remember(fingerprint) + return + + print("\n[apply_patch] approval required") + print(f"- type: {operation.type}") + print(f"- path: {display_path}") + if operation.diff: + preview = operation.diff if len(operation.diff) < 400 else f"{operation.diff[:400]}…" + print("- diff preview:\n", preview) + answer = input("Proceed? [y/N] ").strip().lower() + if answer not in {"y", "yes"}: + raise RuntimeError("Apply patch operation rejected by user.") + self._approvals.remember(fingerprint) + + +async def main(auto_approve: bool, model: str) -> None: + with trace("apply_patch_example"): + with tempfile.TemporaryDirectory(prefix="apply-patch-example-") as workspace: + workspace_path = Path(workspace).resolve() + approvals = ApprovalTracker() + editor = WorkspaceEditor(workspace_path, approvals, auto_approve) + tool = ApplyPatchTool(editor=editor) + previous_response_id: str | None = None + + agent = Agent( + name="Patch Assistant", + model=model, + instructions=( + f"You can edit files inside {workspace_path} using the apply_patch tool. " + "When modifying an existing file, include the file contents between " + " and in your prompt." + ), + tools=[tool], + model_settings=ModelSettings(tool_choice="required"), + ) + + print(f"[info] Workspace root: {workspace_path}") + print(f"[info] Using model: {model}") + print("[run] Creating tasks.md") + result = await Runner.run( + agent, + "Create tasks.md with a shopping checklist of 5 entries.", + previous_response_id=previous_response_id, + ) + previous_response_id = result.last_response_id + print(f"[run] Final response #1:\n{result.final_output}\n") + notes_path = workspace_path / "tasks.md" + if not notes_path.exists(): + raise RuntimeError(f"{notes_path} was not created by the apply_patch tool.") + updated_notes = notes_path.read_text(encoding="utf-8") + print("[file] tasks.md after creation:\n") + print(updated_notes) + + prompt = ( + "\n" + f"===== tasks.md\n{updated_notes}\n" + "\n" + "Check off the last two items from the file." + ) + print("\n[run] Updating tasks.md") + result2 = await Runner.run( + agent, + prompt, + previous_response_id=previous_response_id, + ) + print(f"[run] Final response #2:\n{result2.final_output}\n") + if not notes_path.exists(): + raise RuntimeError("tasks.md vanished unexpectedly before the second read.") + print("[file] Final tasks.md:\n") + print(notes_path.read_text(encoding="utf-8")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--auto-approve", + action="store_true", + default=False, + help="Skip manual confirmations for apply_patch operations.", + ) + parser.add_argument( + "--model", + default="gpt-5.1", + help="Model ID to use for the agent.", + ) + args = parser.parse_args() + asyncio.run(main(args.auto_approve, args.model)) diff --git a/examples/tools/code_interpreter.py b/examples/tools/code_interpreter.py index 406e570e7..5fcc5f160 100644 --- a/examples/tools/code_interpreter.py +++ b/examples/tools/code_interpreter.py @@ -1,8 +1,16 @@ import asyncio +from collections.abc import Mapping +from typing import Any from agents import Agent, CodeInterpreterTool, Runner, trace +def _get_field(obj: Any, key: str) -> Any: + if isinstance(obj, Mapping): + return obj.get(key) + return getattr(obj, key, None) + + async def main(): agent = Agent( name="Code interpreter", @@ -21,14 +29,19 @@ async def main(): print("Solving math problem...") result = Runner.run_streamed(agent, "What is the square root of273 * 312821 plus 1782?") async for event in result.stream_events(): - if ( - event.type == "run_item_stream_event" - and event.item.type == "tool_call_item" - and event.item.raw_item.type == "code_interpreter_call" - ): - print(f"Code interpreter code:\n```\n{event.item.raw_item.code}\n```\n") - elif event.type == "run_item_stream_event": - print(f"Other event: {event.item.type}") + if event.type != "run_item_stream_event": + continue + + item = event.item + if item.type == "tool_call_item": + raw_call = item.raw_item + if _get_field(raw_call, "type") == "code_interpreter_call": + code = _get_field(raw_call, "code") + if isinstance(code, str): + print(f"Code interpreter code:\n```\n{code}\n```\n") + continue + + print(f"Other event: {event.item.type}") print(f"Final output: {result.final_output}") diff --git a/examples/tools/image_generator.py b/examples/tools/image_generator.py index 747b9ce92..399b51a47 100644 --- a/examples/tools/image_generator.py +++ b/examples/tools/image_generator.py @@ -4,10 +4,18 @@ import subprocess import sys import tempfile +from collections.abc import Mapping +from typing import Any from agents import Agent, ImageGenerationTool, Runner, trace +def _get_field(obj: Any, key: str) -> Any: + if isinstance(obj, Mapping): + return obj.get(key) + return getattr(obj, key, None) + + def open_file(path: str) -> None: if sys.platform.startswith("darwin"): subprocess.run(["open", path], check=False) # macOS @@ -37,17 +45,23 @@ async def main(): ) print(result.final_output) for item in result.new_items: - if ( - item.type == "tool_call_item" - and item.raw_item.type == "image_generation_call" - and (img_result := item.raw_item.result) - ): - with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: - tmp.write(base64.b64decode(img_result)) - temp_path = tmp.name - - # Open the image - open_file(temp_path) + if item.type != "tool_call_item": + continue + + raw_call = item.raw_item + call_type = _get_field(raw_call, "type") + if call_type != "image_generation_call": + continue + + img_result = _get_field(raw_call, "result") + if not isinstance(img_result, str): + continue + + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: + tmp.write(base64.b64decode(img_result)) + temp_path = tmp.name + + open_file(temp_path) if __name__ == "__main__": diff --git a/examples/tools/shell.py b/examples/tools/shell.py new file mode 100644 index 000000000..7dcb13309 --- /dev/null +++ b/examples/tools/shell.py @@ -0,0 +1,114 @@ +import argparse +import asyncio +import os +from collections.abc import Sequence +from pathlib import Path + +from agents import ( + Agent, + ModelSettings, + Runner, + ShellCallOutcome, + ShellCommandOutput, + ShellCommandRequest, + ShellResult, + ShellTool, + trace, +) + + +class ShellExecutor: + """Executes shell commands with optional approval.""" + + def __init__(self, cwd: Path | None = None): + self.cwd = Path(cwd or Path.cwd()) + + async def __call__(self, request: ShellCommandRequest) -> ShellResult: + action = request.data.action + await require_approval(action.commands) + + outputs: list[ShellCommandOutput] = [] + for command in action.commands: + proc = await asyncio.create_subprocess_shell( + command, + cwd=self.cwd, + env=os.environ.copy(), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + timed_out = False + try: + timeout = (action.timeout_ms or 0) / 1000 or None + stdout_bytes, stderr_bytes = await asyncio.wait_for( + proc.communicate(), timeout=timeout + ) + except asyncio.TimeoutError: + proc.kill() + stdout_bytes, stderr_bytes = await proc.communicate() + timed_out = True + + stdout = stdout_bytes.decode("utf-8", errors="ignore") + stderr = stderr_bytes.decode("utf-8", errors="ignore") + outputs.append( + ShellCommandOutput( + command=command, + stdout=stdout, + stderr=stderr, + outcome=ShellCallOutcome( + type="timeout" if timed_out else "exit", + exit_code=getattr(proc, "returncode", None), + ), + ) + ) + + if timed_out: + break + + return ShellResult( + output=outputs, + provider_data={"working_directory": str(self.cwd)}, + ) + + +async def require_approval(commands: Sequence[str]) -> None: + if os.environ.get("SHELL_AUTO_APPROVE") == "1": + return + print("Shell command approval required:") + for entry in commands: + print(" ", entry) + response = input("Proceed? [y/N] ").strip().lower() + if response not in {"y", "yes"}: + raise RuntimeError("Shell command execution rejected by user.") + + +async def main(prompt: str, model: str) -> None: + with trace("shell_example"): + print(f"[info] Using model: {model}") + agent = Agent( + name="Shell Assistant", + model=model, + instructions=( + "You can run shell commands using the shell tool. " + "Keep responses concise and include command output when helpful." + ), + tools=[ShellTool(executor=ShellExecutor())], + model_settings=ModelSettings(tool_choice="required"), + ) + + result = await Runner.run(agent, prompt) + print(f"\nFinal response:\n{result.final_output}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--prompt", + default="Show the list of files in the current directory.", + help="Instruction to send to the agent.", + ) + parser.add_argument( + "--model", + default="gpt-5.1", + ) + args = parser.parse_args() + asyncio.run(main(args.prompt, args.model)) diff --git a/examples/tools/web_search_filters.py b/examples/tools/web_search_filters.py index 6be30b169..1e1ff0a11 100644 --- a/examples/tools/web_search_filters.py +++ b/examples/tools/web_search_filters.py @@ -1,11 +1,20 @@ import asyncio +from collections.abc import Mapping from datetime import datetime +from typing import Any from openai.types.responses.web_search_tool import Filters from openai.types.shared.reasoning import Reasoning from agents import Agent, ModelSettings, Runner, WebSearchTool, trace + +def _get_field(obj: Any, key: str) -> Any: + if isinstance(obj, Mapping): + return obj.get(key) + return getattr(obj, key, None) + + # import logging # logging.basicConfig(level=logging.DEBUG) @@ -46,10 +55,25 @@ async def main(): print("### Sources ###") print() for item in result.new_items: - if item.type == "tool_call_item": - if item.raw_item.type == "web_search_call": - for source in item.raw_item.action.sources: # type: ignore [union-attr] - print(f"- {source.url}") + if item.type != "tool_call_item": + continue + + raw_call = item.raw_item + call_type = _get_field(raw_call, "type") + if call_type != "web_search_call": + continue + + action = _get_field(raw_call, "action") + sources = _get_field(action, "sources") if action else None + if not sources: + continue + + for source in sources: + url = getattr(source, "url", None) + if url is None and isinstance(source, Mapping): + url = source.get("url") + if url: + print(f"- {url}") print() print("### Final output ###") print() diff --git a/src/agents/__init__.py b/src/agents/__init__.py index b285d6f8c..c6d28aee0 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -13,7 +13,9 @@ ToolsToFinalOutputResult, ) from .agent_output import AgentOutputSchema, AgentOutputSchemaBase +from .apply_diff import apply_diff from .computer import AsyncComputer, Button, Computer, Environment +from .editor import ApplyPatchEditor, ApplyPatchOperation, ApplyPatchResult from .exceptions import ( AgentsException, InputGuardrailTripwireTriggered, @@ -48,7 +50,12 @@ TResponseInputItem, ) from .lifecycle import AgentHooks, RunHooks -from .memory import OpenAIConversationsSession, Session, SessionABC, SQLiteSession +from .memory import ( + OpenAIConversationsSession, + Session, + SessionABC, + SQLiteSession, +) from .model_settings import ModelSettings from .models.interface import Model, ModelProvider, ModelTracing from .models.multi_provider import MultiProvider @@ -67,6 +74,7 @@ StreamEvent, ) from .tool import ( + ApplyPatchTool, CodeInterpreterTool, ComputerTool, FileSearchTool, @@ -80,6 +88,14 @@ MCPToolApprovalFunction, MCPToolApprovalFunctionResult, MCPToolApprovalRequest, + ShellActionRequest, + ShellCallData, + ShellCallOutcome, + ShellCommandOutput, + ShellCommandRequest, + ShellExecutor, + ShellResult, + ShellTool, Tool, ToolOutputFileContent, ToolOutputFileContentDict, @@ -192,6 +208,7 @@ def enable_verbose_stdout_logging(): "ToolsToFinalOutputFunction", "ToolsToFinalOutputResult", "Runner", + "apply_diff", "run_demo_loop", "Model", "ModelProvider", @@ -273,6 +290,18 @@ def enable_verbose_stdout_logging(): "LocalShellCommandRequest", "LocalShellExecutor", "LocalShellTool", + "ShellActionRequest", + "ShellCallData", + "ShellCallOutcome", + "ShellCommandOutput", + "ShellCommandRequest", + "ShellExecutor", + "ShellResult", + "ShellTool", + "ApplyPatchEditor", + "ApplyPatchOperation", + "ApplyPatchResult", + "ApplyPatchTool", "Tool", "WebSearchTool", "HostedMCPTool", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 88a770a56..d3bd74f9d 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -3,12 +3,14 @@ import asyncio import dataclasses import inspect -from collections.abc import Awaitable +import json +from collections.abc import Awaitable, Mapping, Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, cast from openai.types.responses import ( ResponseComputerToolCall, + ResponseCustomToolCall, ResponseFileSearchToolCall, ResponseFunctionToolCall, ResponseFunctionWebSearch, @@ -44,6 +46,7 @@ from .agent import Agent, ToolsToFinalOutputResult from .agent_output import AgentOutputSchemaBase from .computer import AsyncComputer, Computer +from .editor import ApplyPatchOperation, ApplyPatchResult from .exceptions import ( AgentsException, ModelBehaviorError, @@ -75,6 +78,7 @@ from .run_context import RunContextWrapper, TContext from .stream_events import RunItemStreamEvent, StreamEvent from .tool import ( + ApplyPatchTool, ComputerTool, ComputerToolSafetyCheckData, FunctionTool, @@ -83,6 +87,13 @@ LocalShellCommandRequest, LocalShellTool, MCPToolApprovalRequest, + ShellActionRequest, + ShellCallData, + ShellCallOutcome, + ShellCommandOutput, + ShellCommandRequest, + ShellResult, + ShellTool, Tool, ) from .tool_context import ToolContext @@ -163,6 +174,18 @@ class ToolRunLocalShellCall: local_shell_tool: LocalShellTool +@dataclass +class ToolRunShellCall: + tool_call: Any + shell_tool: ShellTool + + +@dataclass +class ToolRunApplyPatchCall: + tool_call: Any + apply_patch_tool: ApplyPatchTool + + @dataclass class ProcessedResponse: new_items: list[RunItem] @@ -170,6 +193,8 @@ class ProcessedResponse: functions: list[ToolRunFunction] computer_actions: list[ToolRunComputerAction] local_shell_calls: list[ToolRunLocalShellCall] + shell_calls: list[ToolRunShellCall] + apply_patch_calls: list[ToolRunApplyPatchCall] tools_used: list[str] # Names of all tools used, including hosted tools mcp_approval_requests: list[ToolRunMCPApprovalRequest] # Only requests with callbacks @@ -182,6 +207,8 @@ def has_tools_or_approvals_to_run(self) -> bool: self.functions, self.computer_actions, self.local_shell_calls, + self.shell_calls, + self.apply_patch_calls, self.mcp_approval_requests, ] ) @@ -267,10 +294,13 @@ async def execute_tools_and_side_effects( new_step_items: list[RunItem] = [] new_step_items.extend(processed_response.new_items) - # First, lets run the tool calls - function tools, computer actions, and local shell calls + # First, run function tools, computer actions, shell calls, apply_patch calls, + # and legacy local shell calls. ( (function_results, tool_input_guardrail_results, tool_output_guardrail_results), computer_results, + shell_results, + apply_patch_results, local_shell_results, ) = await asyncio.gather( cls.execute_function_tool_calls( @@ -287,6 +317,20 @@ async def execute_tools_and_side_effects( context_wrapper=context_wrapper, config=run_config, ), + cls.execute_shell_calls( + agent=agent, + calls=processed_response.shell_calls, + hooks=hooks, + context_wrapper=context_wrapper, + config=run_config, + ), + cls.execute_apply_patch_calls( + agent=agent, + calls=processed_response.apply_patch_calls, + hooks=hooks, + context_wrapper=context_wrapper, + config=run_config, + ), cls.execute_local_shell_calls( agent=agent, calls=processed_response.local_shell_calls, @@ -297,6 +341,8 @@ async def execute_tools_and_side_effects( ) new_step_items.extend([result.run_item for result in function_results]) new_step_items.extend(computer_results) + new_step_items.extend(shell_results) + new_step_items.extend(apply_patch_results) new_step_items.extend(local_shell_results) # Next, run the MCP approval requests @@ -431,6 +477,8 @@ def process_model_response( functions = [] computer_actions = [] local_shell_calls = [] + shell_calls = [] + apply_patch_calls = [] mcp_approval_requests = [] tools_used: list[str] = [] handoff_map = {handoff.tool_name: handoff for handoff in handoffs} @@ -439,6 +487,10 @@ def process_model_response( local_shell_tool = next( (tool for tool in all_tools if isinstance(tool, LocalShellTool)), None ) + shell_tool = next((tool for tool in all_tools if isinstance(tool, ShellTool)), None) + apply_patch_tool = next( + (tool for tool in all_tools if isinstance(tool, ApplyPatchTool)), None + ) hosted_mcp_server_map = { tool.tool_config["server_label"]: tool for tool in all_tools @@ -446,6 +498,56 @@ def process_model_response( } for output in response.output: + output_type = _get_mapping_or_attr(output, "type") + logger.debug( + "Processing output item type=%s class=%s", + output_type, + output.__class__.__name__ if hasattr(output, "__class__") else type(output), + ) + if output_type == "shell_call": + items.append(ToolCallItem(raw_item=cast(Any, output), agent=agent)) + if not shell_tool: + tools_used.append("shell") + _error_tracing.attach_error_to_current_span( + SpanError( + message="Shell tool not found", + data={}, + ) + ) + raise ModelBehaviorError("Model produced shell call without a shell tool.") + tools_used.append(shell_tool.name) + call_identifier = _get_mapping_or_attr(output, "call_id") or _get_mapping_or_attr( + output, "callId" + ) + logger.debug("Queuing shell_call %s", call_identifier) + shell_calls.append(ToolRunShellCall(tool_call=output, shell_tool=shell_tool)) + continue + if output_type == "apply_patch_call": + items.append(ToolCallItem(raw_item=cast(Any, output), agent=agent)) + if apply_patch_tool: + tools_used.append(apply_patch_tool.name) + call_identifier = _get_mapping_or_attr(output, "call_id") + if not call_identifier: + call_identifier = _get_mapping_or_attr(output, "callId") + logger.debug("Queuing apply_patch_call %s", call_identifier) + apply_patch_calls.append( + ToolRunApplyPatchCall( + tool_call=output, + apply_patch_tool=apply_patch_tool, + ) + ) + else: + tools_used.append("apply_patch") + _error_tracing.attach_error_to_current_span( + SpanError( + message="Apply patch tool not found", + data={}, + ) + ) + raise ModelBehaviorError( + "Model produced apply_patch call without an apply_patch tool." + ) + continue if isinstance(output, ResponseOutputMessage): items.append(MessageOutputItem(raw_item=output, agent=agent)) elif isinstance(output, ResponseFileSearchToolCall): @@ -508,20 +610,84 @@ def process_model_response( tools_used.append("code_interpreter") elif isinstance(output, LocalShellCall): items.append(ToolCallItem(raw_item=output, agent=agent)) - tools_used.append("local_shell") - if not local_shell_tool: + if shell_tool: + tools_used.append(shell_tool.name) + shell_calls.append(ToolRunShellCall(tool_call=output, shell_tool=shell_tool)) + else: + tools_used.append("local_shell") + if not local_shell_tool: + _error_tracing.attach_error_to_current_span( + SpanError( + message="Local shell tool not found", + data={}, + ) + ) + raise ModelBehaviorError( + "Model produced local shell call without a local shell tool." + ) + local_shell_calls.append( + ToolRunLocalShellCall(tool_call=output, local_shell_tool=local_shell_tool) + ) + elif isinstance(output, ResponseCustomToolCall) and _is_apply_patch_name( + output.name, apply_patch_tool + ): + parsed_operation = _parse_apply_patch_custom_input(output.input) + pseudo_call = { + "type": "apply_patch_call", + "call_id": output.call_id, + "operation": parsed_operation, + } + items.append(ToolCallItem(raw_item=cast(Any, pseudo_call), agent=agent)) + if apply_patch_tool: + tools_used.append(apply_patch_tool.name) + apply_patch_calls.append( + ToolRunApplyPatchCall( + tool_call=pseudo_call, + apply_patch_tool=apply_patch_tool, + ) + ) + else: + tools_used.append("apply_patch") _error_tracing.attach_error_to_current_span( SpanError( - message="Local shell tool not found", + message="Apply patch tool not found", data={}, ) ) raise ModelBehaviorError( - "Model produced local shell call without a local shell tool." + "Model produced apply_patch call without an apply_patch tool." ) - local_shell_calls.append( - ToolRunLocalShellCall(tool_call=output, local_shell_tool=local_shell_tool) - ) + elif ( + isinstance(output, ResponseFunctionToolCall) + and _is_apply_patch_name(output.name, apply_patch_tool) + and output.name not in function_map + ): + parsed_operation = _parse_apply_patch_function_args(output.arguments) + pseudo_call = { + "type": "apply_patch_call", + "call_id": output.call_id, + "operation": parsed_operation, + } + items.append(ToolCallItem(raw_item=cast(Any, pseudo_call), agent=agent)) + if apply_patch_tool: + tools_used.append(apply_patch_tool.name) + apply_patch_calls.append( + ToolRunApplyPatchCall( + tool_call=pseudo_call, apply_patch_tool=apply_patch_tool + ) + ) + else: + tools_used.append("apply_patch") + _error_tracing.attach_error_to_current_span( + SpanError( + message="Apply patch tool not found", + data={}, + ) + ) + raise ModelBehaviorError( + "Model produced apply_patch call without an apply_patch tool." + ) + continue elif not isinstance(output, ResponseFunctionToolCall): logger.warning(f"Unexpected output type, ignoring: {type(output)}") @@ -581,6 +747,8 @@ def process_model_response( functions=functions, computer_actions=computer_actions, local_shell_calls=local_shell_calls, + shell_calls=shell_calls, + apply_patch_calls=apply_patch_calls, tools_used=tools_used, mcp_approval_requests=mcp_approval_requests, ) @@ -865,6 +1033,52 @@ async def execute_local_shell_calls( ) return results + @classmethod + async def execute_shell_calls( + cls, + *, + agent: Agent[TContext], + calls: list[ToolRunShellCall], + context_wrapper: RunContextWrapper[TContext], + hooks: RunHooks[TContext], + config: RunConfig, + ) -> list[RunItem]: + results: list[RunItem] = [] + for call in calls: + results.append( + await ShellAction.execute( + agent=agent, + call=call, + hooks=hooks, + context_wrapper=context_wrapper, + config=config, + ) + ) + return results + + @classmethod + async def execute_apply_patch_calls( + cls, + *, + agent: Agent[TContext], + calls: list[ToolRunApplyPatchCall], + context_wrapper: RunContextWrapper[TContext], + hooks: RunHooks[TContext], + config: RunConfig, + ) -> list[RunItem]: + results: list[RunItem] = [] + for call in calls: + results.append( + await ApplyPatchAction.execute( + agent=agent, + call=call, + hooks=hooks, + context_wrapper=context_wrapper, + config=config, + ) + ) + return results + @classmethod async def execute_computer_actions( cls, @@ -1423,18 +1637,488 @@ async def execute( ), ) + raw_payload: dict[str, Any] = { + "type": "local_shell_call_output", + "call_id": call.tool_call.call_id, + "output": result, + } return ToolCallOutputItem( agent=agent, output=result, - # LocalShellCallOutput type uses the field name "id", but the server wants "call_id". - # raw_item keeps the upstream type, so we ignore the type checker here. - raw_item={ # type: ignore[misc, arg-type] - "type": "local_shell_call_output", - "call_id": call.tool_call.call_id, - "output": result, - }, + raw_item=raw_payload, + ) + + +class ShellAction: + @classmethod + async def execute( + cls, + *, + agent: Agent[TContext], + call: ToolRunShellCall, + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + config: RunConfig, + ) -> RunItem: + await asyncio.gather( + hooks.on_tool_start(context_wrapper, agent, call.shell_tool), + ( + agent.hooks.on_tool_start(context_wrapper, agent, call.shell_tool) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + shell_call = _coerce_shell_call(call.tool_call) + request = ShellCommandRequest(ctx_wrapper=context_wrapper, data=shell_call) + status: Literal["completed", "failed"] = "completed" + output_text = "" + shell_output_payload: list[dict[str, Any]] | None = None + provider_meta: dict[str, Any] | None = None + max_output_length: int | None = None + + try: + executor_result = call.shell_tool.executor(request) + result = ( + await executor_result if inspect.isawaitable(executor_result) else executor_result + ) + + if isinstance(result, ShellResult): + normalized = [_normalize_shell_output(entry) for entry in result.output] + output_text = _render_shell_outputs(normalized) + shell_output_payload = [_serialize_shell_output(entry) for entry in normalized] + provider_meta = dict(result.provider_data or {}) + max_output_length = result.max_output_length + else: + output_text = str(result) + except Exception as exc: + status = "failed" + output_text = _format_shell_error(exc) + logger.error("Shell executor failed: %s", exc, exc_info=True) + + await asyncio.gather( + hooks.on_tool_end(context_wrapper, agent, call.shell_tool, output_text), + ( + agent.hooks.on_tool_end(context_wrapper, agent, call.shell_tool, output_text) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + raw_entries: list[dict[str, Any]] | None = None + if shell_output_payload: + raw_entries = shell_output_payload + elif output_text: + raw_entries = [ + { + "stdout": output_text, + "stderr": "", + "status": status, + "outcome": "success" if status == "completed" else "failure", + } + ] + + structured_output: list[dict[str, Any]] = [] + if raw_entries: + for entry in raw_entries: + sanitized = dict(entry) + status_value = sanitized.pop("status", None) + sanitized.pop("provider_data", None) + raw_exit_code = sanitized.pop("exit_code", None) + sanitized.pop("command", None) + outcome_value = sanitized.get("outcome") + if isinstance(outcome_value, str): + resolved_type = "exit" + if status_value == "timeout": + resolved_type = "timeout" + outcome_payload: dict[str, Any] = {"type": resolved_type} + if resolved_type == "exit": + outcome_payload["exit_code"] = _resolve_exit_code( + raw_exit_code, outcome_value + ) + sanitized["outcome"] = outcome_payload + elif isinstance(outcome_value, Mapping): + outcome_payload = dict(outcome_value) + outcome_status = cast(Optional[str], outcome_payload.pop("status", None)) + outcome_type = outcome_payload.get("type") + if outcome_type != "timeout": + outcome_payload.setdefault( + "exit_code", + _resolve_exit_code( + raw_exit_code, + outcome_status if isinstance(outcome_status, str) else None, + ), + ) + sanitized["outcome"] = outcome_payload + structured_output.append(sanitized) + + raw_item: dict[str, Any] = { + "type": "shell_call_output", + "call_id": shell_call.call_id, + "output": structured_output, + "status": status, + } + if max_output_length is not None: + raw_item["max_output_length"] = max_output_length + if raw_entries: + raw_item["shell_output"] = raw_entries + if provider_meta: + raw_item["provider_data"] = provider_meta + + return ToolCallOutputItem( + agent=agent, + output=output_text, + raw_item=cast(Any, raw_item), + ) + + +class ApplyPatchAction: + @classmethod + async def execute( + cls, + *, + agent: Agent[TContext], + call: ToolRunApplyPatchCall, + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + config: RunConfig, + ) -> RunItem: + apply_patch_tool = call.apply_patch_tool + await asyncio.gather( + hooks.on_tool_start(context_wrapper, agent, apply_patch_tool), + ( + agent.hooks.on_tool_start(context_wrapper, agent, apply_patch_tool) + if agent.hooks + else _coro.noop_coroutine() + ), ) + status: Literal["completed", "failed"] = "completed" + output_text = "" + + try: + operation = _coerce_apply_patch_operation(call.tool_call) + editor = apply_patch_tool.editor + if operation.type == "create_file": + result = editor.create_file(operation) + elif operation.type == "update_file": + result = editor.update_file(operation) + elif operation.type == "delete_file": + result = editor.delete_file(operation) + else: # pragma: no cover - validated in _coerce_apply_patch_operation + raise ModelBehaviorError(f"Unsupported apply_patch operation: {operation.type}") + + awaited = await result if inspect.isawaitable(result) else result + normalized = _normalize_apply_patch_result(awaited) + if normalized: + if normalized.status in {"completed", "failed"}: + status = normalized.status + if normalized.output: + output_text = normalized.output + except Exception as exc: + status = "failed" + output_text = _format_shell_error(exc) + logger.error("Apply patch editor failed: %s", exc, exc_info=True) + + await asyncio.gather( + hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text), + ( + agent.hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + raw_item: dict[str, Any] = { + "type": "apply_patch_call_output", + "call_id": _extract_apply_patch_call_id(call.tool_call), + "status": status, + } + if output_text: + raw_item["output"] = output_text + + return ToolCallOutputItem( + agent=agent, + output=output_text, + raw_item=cast(Any, raw_item), + ) + + +def _normalize_shell_output(entry: ShellCommandOutput | Mapping[str, Any]) -> ShellCommandOutput: + if isinstance(entry, ShellCommandOutput): + return entry + + stdout = str(entry.get("stdout", "") or "") + stderr = str(entry.get("stderr", "") or "") + command_value = entry.get("command") + provider_data_value = entry.get("provider_data") + outcome_value = entry.get("outcome") + + outcome_type: Literal["exit", "timeout"] = "exit" + exit_code_value: Any | None = None + + if isinstance(outcome_value, Mapping): + type_value = outcome_value.get("type") + if type_value == "timeout": + outcome_type = "timeout" + elif isinstance(type_value, str): + outcome_type = "exit" + exit_code_value = outcome_value.get("exit_code") or outcome_value.get("exitCode") + else: + status_str = str(entry.get("status", "completed") or "completed").lower() + if status_str == "timeout": + outcome_type = "timeout" + if isinstance(outcome_value, str): + if outcome_value == "failure": + exit_code_value = 1 + elif outcome_value == "success": + exit_code_value = 0 + exit_code_value = exit_code_value or entry.get("exit_code") or entry.get("exitCode") + + outcome = ShellCallOutcome( + type=outcome_type, + exit_code=_normalize_exit_code(exit_code_value), + ) + + return ShellCommandOutput( + stdout=stdout, + stderr=stderr, + outcome=outcome, + command=str(command_value) if command_value is not None else None, + provider_data=cast(dict[str, Any], provider_data_value) + if isinstance(provider_data_value, Mapping) + else provider_data_value, + ) + + +def _serialize_shell_output(output: ShellCommandOutput) -> dict[str, Any]: + payload: dict[str, Any] = { + "stdout": output.stdout, + "stderr": output.stderr, + "status": output.status, + "outcome": {"type": output.outcome.type}, + } + if output.outcome.type == "exit": + payload["outcome"]["exit_code"] = output.outcome.exit_code + if output.outcome.exit_code is not None: + payload["exit_code"] = output.outcome.exit_code + if output.command is not None: + payload["command"] = output.command + if output.provider_data: + payload["provider_data"] = output.provider_data + return payload + + +def _resolve_exit_code(raw_exit_code: Any, outcome_status: str | None) -> int: + normalized = _normalize_exit_code(raw_exit_code) + if normalized is not None: + return normalized + + normalized_status = (outcome_status or "").lower() + if normalized_status == "success": + return 0 + if normalized_status == "failure": + return 1 + return 0 + + +def _normalize_exit_code(value: Any) -> int | None: + if value is None: + return None + try: + return int(value) + except (TypeError, ValueError): + return None + + +def _render_shell_outputs(outputs: Sequence[ShellCommandOutput]) -> str: + if not outputs: + return "(no output)" + + rendered_chunks: list[str] = [] + for result in outputs: + chunk_lines: list[str] = [] + if result.command: + chunk_lines.append(f"$ {result.command}") + + stdout = result.stdout.rstrip("\n") + stderr = result.stderr.rstrip("\n") + + if stdout: + chunk_lines.append(stdout) + if stderr: + if stdout: + chunk_lines.append("") + chunk_lines.append("stderr:") + chunk_lines.append(stderr) + + if result.exit_code not in (None, 0): + chunk_lines.append(f"exit code: {result.exit_code}") + if result.status == "timeout": + chunk_lines.append("status: timeout") + + chunk = "\n".join(chunk_lines).strip() + rendered_chunks.append(chunk if chunk else "(no output)") + + return "\n\n".join(rendered_chunks) + + +def _format_shell_error(error: Exception | BaseException | Any) -> str: + if isinstance(error, Exception): + message = str(error) + return message or error.__class__.__name__ + try: + return str(error) + except Exception: # pragma: no cover - fallback only + return repr(error) + + +def _get_mapping_or_attr(target: Any, key: str) -> Any: + if isinstance(target, Mapping): + return target.get(key) + return getattr(target, key, None) + + +def _extract_shell_call_id(tool_call: Any) -> str: + value = _get_mapping_or_attr(tool_call, "call_id") + if not value: + value = _get_mapping_or_attr(tool_call, "callId") + if not value: + raise ModelBehaviorError("Shell call is missing call_id.") + return str(value) + + +def _coerce_shell_call(tool_call: Any) -> ShellCallData: + call_id = _extract_shell_call_id(tool_call) + action_payload = _get_mapping_or_attr(tool_call, "action") + if action_payload is None: + raise ModelBehaviorError("Shell call is missing an action payload.") + + commands_value = _get_mapping_or_attr(action_payload, "commands") + if not isinstance(commands_value, Sequence): + raise ModelBehaviorError("Shell call action is missing commands.") + commands: list[str] = [] + for entry in commands_value: + if entry is None: + continue + commands.append(str(entry)) + if not commands: + raise ModelBehaviorError("Shell call action must include at least one command.") + + timeout_value = ( + _get_mapping_or_attr(action_payload, "timeout_ms") + or _get_mapping_or_attr(action_payload, "timeoutMs") + or _get_mapping_or_attr(action_payload, "timeout") + ) + timeout_ms = int(timeout_value) if isinstance(timeout_value, (int, float)) else None + + max_length_value = _get_mapping_or_attr( + action_payload, "max_output_length" + ) or _get_mapping_or_attr(action_payload, "maxOutputLength") + max_output_length = ( + int(max_length_value) if isinstance(max_length_value, (int, float)) else None + ) + + action = ShellActionRequest( + commands=commands, + timeout_ms=timeout_ms, + max_output_length=max_output_length, + ) + + status_value = _get_mapping_or_attr(tool_call, "status") + status_literal: Literal["in_progress", "completed"] | None = None + if isinstance(status_value, str): + lowered = status_value.lower() + if lowered in {"in_progress", "completed"}: + status_literal = cast(Literal["in_progress", "completed"], lowered) + + return ShellCallData(call_id=call_id, action=action, status=status_literal, raw=tool_call) + + +def _parse_apply_patch_custom_input(input_json: str) -> dict[str, Any]: + try: + parsed = json.loads(input_json or "{}") + except json.JSONDecodeError as exc: + raise ModelBehaviorError(f"Invalid apply_patch input JSON: {exc}") from exc + if not isinstance(parsed, Mapping): + raise ModelBehaviorError("Apply patch input must be a JSON object.") + return dict(parsed) + + +def _parse_apply_patch_function_args(arguments: str) -> dict[str, Any]: + try: + parsed = json.loads(arguments or "{}") + except json.JSONDecodeError as exc: + raise ModelBehaviorError(f"Invalid apply_patch arguments JSON: {exc}") from exc + if not isinstance(parsed, Mapping): + raise ModelBehaviorError("Apply patch arguments must be a JSON object.") + return dict(parsed) + + +def _extract_apply_patch_call_id(tool_call: Any) -> str: + value = _get_mapping_or_attr(tool_call, "call_id") + if not value: + value = _get_mapping_or_attr(tool_call, "callId") + if not value: + raise ModelBehaviorError("Apply patch call is missing call_id.") + return str(value) + + +def _coerce_apply_patch_operation(tool_call: Any) -> ApplyPatchOperation: + raw_operation = _get_mapping_or_attr(tool_call, "operation") + if raw_operation is None: + raise ModelBehaviorError("Apply patch call is missing an operation payload.") + + op_type_value = str(_get_mapping_or_attr(raw_operation, "type")) + if op_type_value not in {"create_file", "update_file", "delete_file"}: + raise ModelBehaviorError(f"Unknown apply_patch operation: {op_type_value}") + op_type_literal = cast(Literal["create_file", "update_file", "delete_file"], op_type_value) + + path = _get_mapping_or_attr(raw_operation, "path") + if not isinstance(path, str) or not path: + raise ModelBehaviorError("Apply patch operation is missing a valid path.") + + diff_value = _get_mapping_or_attr(raw_operation, "diff") + if op_type_literal in {"create_file", "update_file"}: + if not isinstance(diff_value, str) or not diff_value: + raise ModelBehaviorError( + f"Apply patch operation {op_type_literal} is missing the required diff payload." + ) + diff: str | None = diff_value + else: + diff = None + + return ApplyPatchOperation(type=op_type_literal, path=str(path), diff=diff) + + +def _normalize_apply_patch_result( + result: ApplyPatchResult | Mapping[str, Any] | str | None, +) -> ApplyPatchResult | None: + if result is None: + return None + if isinstance(result, ApplyPatchResult): + return result + if isinstance(result, Mapping): + status = result.get("status") + output = result.get("output") + normalized_status = status if status in {"completed", "failed"} else None + normalized_output = str(output) if output is not None else None + return ApplyPatchResult(status=normalized_status, output=normalized_output) + if isinstance(result, str): + return ApplyPatchResult(output=result) + return ApplyPatchResult(output=str(result)) + + +def _is_apply_patch_name(name: str | None, tool: ApplyPatchTool | None) -> bool: + if not name: + return False + candidate = name.strip().lower() + if candidate.startswith("apply_patch"): + return True + if tool and candidate == tool.name.strip().lower(): + return True + return False + def _build_litellm_json_tool_call(output: ResponseFunctionToolCall) -> FunctionTool: async def on_invoke_tool(_ctx: ToolContext[Any], value: Any) -> Any: diff --git a/src/agents/apply_diff.py b/src/agents/apply_diff.py new file mode 100644 index 000000000..e1606e359 --- /dev/null +++ b/src/agents/apply_diff.py @@ -0,0 +1,329 @@ +"""Utility for applying V4A diffs against text inputs.""" + +from __future__ import annotations + +import re +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Callable, Literal + +ApplyDiffMode = Literal["default", "create"] + + +@dataclass +class Chunk: + orig_index: int + del_lines: list[str] + ins_lines: list[str] + + +@dataclass +class ParserState: + lines: list[str] + index: int = 0 + fuzz: int = 0 + + +@dataclass +class ParsedUpdateDiff: + chunks: list[Chunk] + fuzz: int + + +@dataclass +class ReadSectionResult: + next_context: list[str] + section_chunks: list[Chunk] + end_index: int + eof: bool + + +END_PATCH = "*** End Patch" +END_FILE = "*** End of File" +SECTION_TERMINATORS = [ + END_PATCH, + "*** Update File:", + "*** Delete File:", + "*** Add File:", +] +END_SECTION_MARKERS = [*SECTION_TERMINATORS, END_FILE] + + +def apply_diff(input: str, diff: str, mode: ApplyDiffMode = "default") -> str: + """Apply a V4A diff to the provided text. + + This parser understands both the create-file syntax (only "+" prefixed + lines) and the default update syntax that includes context hunks. + """ + + diff_lines = _normalize_diff_lines(diff) + if mode == "create": + return _parse_create_diff(diff_lines) + + parsed = _parse_update_diff(diff_lines, input) + return _apply_chunks(input, parsed.chunks) + + +def _normalize_diff_lines(diff: str) -> list[str]: + lines = [line.rstrip("\r") for line in re.split(r"\r?\n", diff)] + if lines and lines[-1] == "": + lines.pop() + return lines + + +def _is_done(state: ParserState, prefixes: Sequence[str]) -> bool: + if state.index >= len(state.lines): + return True + if any(state.lines[state.index].startswith(prefix) for prefix in prefixes): + return True + return False + + +def _read_str(state: ParserState, prefix: str) -> str: + if state.index >= len(state.lines): + return "" + current = state.lines[state.index] + if current.startswith(prefix): + state.index += 1 + return current[len(prefix) :] + return "" + + +def _parse_create_diff(lines: list[str]) -> str: + parser = ParserState(lines=[*lines, END_PATCH]) + output: list[str] = [] + + while not _is_done(parser, SECTION_TERMINATORS): + if parser.index >= len(parser.lines): + break + line = parser.lines[parser.index] + parser.index += 1 + if not line.startswith("+"): + raise ValueError(f"Invalid Add File Line: {line}") + output.append(line[1:]) + + return "\n".join(output) + + +def _parse_update_diff(lines: list[str], input: str) -> ParsedUpdateDiff: + parser = ParserState(lines=[*lines, END_PATCH]) + input_lines = input.split("\n") + chunks: list[Chunk] = [] + cursor = 0 + + while not _is_done(parser, END_SECTION_MARKERS): + anchor = _read_str(parser, "@@ ") + has_bare_anchor = ( + anchor == "" and parser.index < len(parser.lines) and parser.lines[parser.index] == "@@" + ) + if has_bare_anchor: + parser.index += 1 + + if not (anchor or has_bare_anchor or cursor == 0): + current_line = parser.lines[parser.index] if parser.index < len(parser.lines) else "" + raise ValueError(f"Invalid Line:\n{current_line}") + + if anchor.strip(): + cursor = _advance_cursor_to_anchor(anchor, input_lines, cursor, parser) + + section = _read_section(parser.lines, parser.index) + find_result = _find_context(input_lines, section.next_context, cursor, section.eof) + if find_result.new_index == -1: + ctx_text = "\n".join(section.next_context) + if section.eof: + raise ValueError(f"Invalid EOF Context {cursor}:\n{ctx_text}") + raise ValueError(f"Invalid Context {cursor}:\n{ctx_text}") + + cursor = find_result.new_index + len(section.next_context) + parser.fuzz += find_result.fuzz + parser.index = section.end_index + + for ch in section.section_chunks: + chunks.append( + Chunk( + orig_index=ch.orig_index + find_result.new_index, + del_lines=list(ch.del_lines), + ins_lines=list(ch.ins_lines), + ) + ) + + return ParsedUpdateDiff(chunks=chunks, fuzz=parser.fuzz) + + +def _advance_cursor_to_anchor( + anchor: str, + input_lines: list[str], + cursor: int, + parser: ParserState, +) -> int: + found = False + + if not any(line == anchor for line in input_lines[:cursor]): + for i in range(cursor, len(input_lines)): + if input_lines[i] == anchor: + cursor = i + 1 + found = True + break + + if not found and not any(line.strip() == anchor.strip() for line in input_lines[:cursor]): + for i in range(cursor, len(input_lines)): + if input_lines[i].strip() == anchor.strip(): + cursor = i + 1 + parser.fuzz += 1 + found = True + break + + return cursor + + +def _read_section(lines: list[str], start_index: int) -> ReadSectionResult: + context: list[str] = [] + del_lines: list[str] = [] + ins_lines: list[str] = [] + section_chunks: list[Chunk] = [] + mode: Literal["keep", "add", "delete"] = "keep" + index = start_index + orig_index = index + + while index < len(lines): + raw = lines[index] + if ( + raw.startswith("@@") + or raw.startswith(END_PATCH) + or raw.startswith("*** Update File:") + or raw.startswith("*** Delete File:") + or raw.startswith("*** Add File:") + or raw.startswith(END_FILE) + ): + break + if raw == "***": + break + if raw.startswith("***"): + raise ValueError(f"Invalid Line: {raw}") + + index += 1 + last_mode = mode + line = raw if raw else " " + prefix = line[0] + if prefix == "+": + mode = "add" + elif prefix == "-": + mode = "delete" + elif prefix == " ": + mode = "keep" + else: + raise ValueError(f"Invalid Line: {line}") + + line_content = line[1:] + switching_to_context = mode == "keep" and last_mode != mode + if switching_to_context and (del_lines or ins_lines): + section_chunks.append( + Chunk( + orig_index=len(context) - len(del_lines), + del_lines=list(del_lines), + ins_lines=list(ins_lines), + ) + ) + del_lines = [] + ins_lines = [] + + if mode == "delete": + del_lines.append(line_content) + context.append(line_content) + elif mode == "add": + ins_lines.append(line_content) + else: + context.append(line_content) + + if del_lines or ins_lines: + section_chunks.append( + Chunk( + orig_index=len(context) - len(del_lines), + del_lines=list(del_lines), + ins_lines=list(ins_lines), + ) + ) + + if index < len(lines) and lines[index] == END_FILE: + return ReadSectionResult(context, section_chunks, index + 1, True) + + if index == orig_index: + next_line = lines[index] if index < len(lines) else "" + raise ValueError(f"Nothing in this section - index={index} {next_line}") + + return ReadSectionResult(context, section_chunks, index, False) + + +@dataclass +class ContextMatch: + new_index: int + fuzz: int + + +def _find_context(lines: list[str], context: list[str], start: int, eof: bool) -> ContextMatch: + if eof: + end_start = max(0, len(lines) - len(context)) + end_match = _find_context_core(lines, context, end_start) + if end_match.new_index != -1: + return end_match + fallback = _find_context_core(lines, context, start) + return ContextMatch(new_index=fallback.new_index, fuzz=fallback.fuzz + 10000) + return _find_context_core(lines, context, start) + + +def _find_context_core(lines: list[str], context: list[str], start: int) -> ContextMatch: + if not context: + return ContextMatch(new_index=start, fuzz=0) + + for i in range(start, len(lines)): + if _equals_slice(lines, context, i, lambda value: value): + return ContextMatch(new_index=i, fuzz=0) + for i in range(start, len(lines)): + if _equals_slice(lines, context, i, lambda value: value.rstrip()): + return ContextMatch(new_index=i, fuzz=1) + for i in range(start, len(lines)): + if _equals_slice(lines, context, i, lambda value: value.strip()): + return ContextMatch(new_index=i, fuzz=100) + + return ContextMatch(new_index=-1, fuzz=0) + + +def _equals_slice( + source: list[str], target: list[str], start: int, map_fn: Callable[[str], str] +) -> bool: + if start + len(target) > len(source): + return False + for offset, target_value in enumerate(target): + if map_fn(source[start + offset]) != map_fn(target_value): + return False + return True + + +def _apply_chunks(input: str, chunks: list[Chunk]) -> str: + orig_lines = input.split("\n") + dest_lines: list[str] = [] + cursor = 0 + + for chunk in chunks: + if chunk.orig_index > len(orig_lines): + raise ValueError( + f"applyDiff: chunk.origIndex {chunk.orig_index} > input length {len(orig_lines)}" + ) + if cursor > chunk.orig_index: + raise ValueError( + f"applyDiff: overlapping chunk at {chunk.orig_index} (cursor {cursor})" + ) + + dest_lines.extend(orig_lines[cursor : chunk.orig_index]) + cursor = chunk.orig_index + + if chunk.ins_lines: + dest_lines.extend(chunk.ins_lines) + + cursor += len(chunk.del_lines) + + dest_lines.extend(orig_lines[cursor:]) + return "\n".join(dest_lines) + + +__all__ = ["apply_diff"] diff --git a/src/agents/editor.py b/src/agents/editor.py new file mode 100644 index 000000000..38dd616b3 --- /dev/null +++ b/src/agents/editor.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import sys +from dataclasses import dataclass +from typing import Literal, Protocol, runtime_checkable + +from .util._types import MaybeAwaitable + +ApplyPatchOperationType = Literal["create_file", "update_file", "delete_file"] + +_DATACLASS_KWARGS = {"slots": True} if sys.version_info >= (3, 10) else {} + + +@dataclass(**_DATACLASS_KWARGS) +class ApplyPatchOperation: + """Represents a single apply_patch editor operation requested by the model.""" + + type: ApplyPatchOperationType + path: str + diff: str | None = None + + +@dataclass(**_DATACLASS_KWARGS) +class ApplyPatchResult: + """Optional metadata returned by editor operations.""" + + status: Literal["completed", "failed"] | None = None + output: str | None = None + + +@runtime_checkable +class ApplyPatchEditor(Protocol): + """Host-defined editor that applies diffs on disk.""" + + def create_file( + self, operation: ApplyPatchOperation + ) -> MaybeAwaitable[ApplyPatchResult | str | None]: ... + + def update_file( + self, operation: ApplyPatchOperation + ) -> MaybeAwaitable[ApplyPatchResult | str | None]: ... + + def delete_file( + self, operation: ApplyPatchOperation + ) -> MaybeAwaitable[ApplyPatchResult | str | None]: ... diff --git a/src/agents/items.py b/src/agents/items.py index 8e7d1cfc3..24defb22d 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -2,7 +2,7 @@ import abc from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, Union +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, Union, cast import pydantic from openai.types.responses import ( @@ -141,12 +141,13 @@ class HandoffOutputItem(RunItemBase[TResponseInputItem]): ResponseCodeInterpreterToolCall, ImageGenerationCall, LocalShellCall, + dict[str, Any], ] """A type that represents a tool call item.""" @dataclass -class ToolCallItem(RunItemBase[ToolCallItemTypes]): +class ToolCallItem(RunItemBase[Any]): """Represents a tool call e.g. a function call or computer action call.""" raw_item: ToolCallItemTypes @@ -155,13 +156,19 @@ class ToolCallItem(RunItemBase[ToolCallItemTypes]): type: Literal["tool_call_item"] = "tool_call_item" +ToolCallOutputTypes: TypeAlias = Union[ + FunctionCallOutput, + ComputerCallOutput, + LocalShellCallOutput, + dict[str, Any], +] + + @dataclass -class ToolCallOutputItem( - RunItemBase[Union[FunctionCallOutput, ComputerCallOutput, LocalShellCallOutput]] -): +class ToolCallOutputItem(RunItemBase[Any]): """Represents the output of a tool call.""" - raw_item: FunctionCallOutput | ComputerCallOutput | LocalShellCallOutput + raw_item: ToolCallOutputTypes """The raw item from the model.""" output: Any @@ -171,6 +178,25 @@ class ToolCallOutputItem( type: Literal["tool_call_output_item"] = "tool_call_output_item" + def to_input_item(self) -> TResponseInputItem: + """Converts the tool output into an input item for the next model turn. + + Hosted tool outputs (e.g. shell/apply_patch) carry a `status` field for the SDK's + book-keeping, but the Responses API does not yet accept that parameter. Strip it from the + payload we send back to the model while keeping the original raw item intact. + """ + + if isinstance(self.raw_item, dict): + payload = dict(self.raw_item) + payload_type = payload.get("type") + if payload_type == "shell_call_output": + payload.pop("status", None) + payload.pop("shell_output", None) + payload.pop("provider_data", None) + return cast(TResponseInputItem, payload) + + return super().to_input_item() + @dataclass class ReasoningItem(RunItemBase[ResponseReasoningItem]): diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index 36a981404..466496b01 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -27,6 +27,7 @@ from ..logger import logger from ..model_settings import MCPToolChoice from ..tool import ( + ApplyPatchTool, CodeInterpreterTool, ComputerTool, FileSearchTool, @@ -34,6 +35,7 @@ HostedMCPTool, ImageGenerationTool, LocalShellTool, + ShellTool, Tool, WebSearchTool, ) @@ -489,6 +491,12 @@ def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, ResponseIncludable | None elif isinstance(tool, HostedMCPTool): converted_tool = tool.tool_config includes = None + elif isinstance(tool, ApplyPatchTool): + converted_tool = cast(ToolParam, {"type": "apply_patch"}) + includes = None + elif isinstance(tool, ShellTool): + converted_tool = cast(ToolParam, {"type": "shell"}) + includes = None elif isinstance(tool, ImageGenerationTool): converted_tool = tool.tool_config includes = None diff --git a/src/agents/run.py b/src/agents/run.py index 5b25df4f2..c14f13e3f 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -6,7 +6,7 @@ import os import warnings from dataclasses import dataclass, field -from typing import Any, Callable, Generic, cast, get_args +from typing import Any, Callable, Generic, cast, get_args, get_origin from openai.types.responses import ( ResponseCompletedEvent, @@ -1886,7 +1886,19 @@ async def _input_guardrail_tripwire_triggered_for_stream( DEFAULT_AGENT_RUNNER = AgentRunner() -_TOOL_CALL_TYPES: tuple[type, ...] = get_args(ToolCallItemTypes) + + +def _get_tool_call_types() -> tuple[type, ...]: + normalized_types: list[type] = [] + for type_hint in get_args(ToolCallItemTypes): + origin = get_origin(type_hint) + candidate = origin or type_hint + if isinstance(candidate, type): + normalized_types.append(candidate) + return tuple(normalized_types) + + +_TOOL_CALL_TYPES: tuple[type, ...] = _get_tool_call_types() def _copy_str_or_list(input: str | list[TResponseInputItem]) -> str | list[TResponseInputItem]: diff --git a/src/agents/tool.py b/src/agents/tool.py index 39db129b7..c3baa6ffc 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -3,7 +3,7 @@ import inspect import json from collections.abc import Awaitable -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable, Literal, Union, overload from openai.types.responses.file_search_tool_param import Filters, RankingOptions @@ -20,6 +20,7 @@ from . import _debug from .computer import AsyncComputer, Computer +from .editor import ApplyPatchEditor from .exceptions import ModelBehaviorError from .function_schema import DocstringStyle, function_schema from .logger import logger @@ -373,12 +374,109 @@ def name(self): return "local_shell" +@dataclass +class ShellCallOutcome: + """Describes the terminal condition of a shell command.""" + + type: Literal["exit", "timeout"] + exit_code: int | None = None + + +def _default_shell_outcome() -> ShellCallOutcome: + return ShellCallOutcome(type="exit") + + +@dataclass +class ShellCommandOutput: + """Structured output for a single shell command execution.""" + + stdout: str = "" + stderr: str = "" + outcome: ShellCallOutcome = field(default_factory=_default_shell_outcome) + command: str | None = None + provider_data: dict[str, Any] | None = None + + @property + def exit_code(self) -> int | None: + return self.outcome.exit_code + + @property + def status(self) -> Literal["completed", "timeout"]: + return "timeout" if self.outcome.type == "timeout" else "completed" + + +@dataclass +class ShellResult: + """Result returned by a shell executor.""" + + output: list[ShellCommandOutput] + max_output_length: int | None = None + provider_data: dict[str, Any] | None = None + + +@dataclass +class ShellActionRequest: + """Action payload for a next-generation shell call.""" + + commands: list[str] + timeout_ms: int | None = None + max_output_length: int | None = None + + +@dataclass +class ShellCallData: + """Normalized shell call data provided to shell executors.""" + + call_id: str + action: ShellActionRequest + status: Literal["in_progress", "completed"] | None = None + raw: Any | None = None + + +@dataclass +class ShellCommandRequest: + """A request to execute a modern shell call.""" + + ctx_wrapper: RunContextWrapper[Any] + data: ShellCallData + + +ShellExecutor = Callable[[ShellCommandRequest], MaybeAwaitable[Union[str, ShellResult]]] +"""Executes a shell command sequence and returns either text or structured output.""" + + +@dataclass +class ShellTool: + """Next-generation shell tool. LocalShellTool will be deprecated in favor of this.""" + + executor: ShellExecutor + name: str = "shell" + + @property + def type(self) -> str: + return "shell" + + +@dataclass +class ApplyPatchTool: + """Hosted apply_patch tool. Lets the model request file mutations via unified diffs.""" + + editor: ApplyPatchEditor + name: str = "apply_patch" + + @property + def type(self) -> str: + return "apply_patch" + + Tool = Union[ FunctionTool, FileSearchTool, WebSearchTool, ComputerTool, HostedMCPTool, + ShellTool, + ApplyPatchTool, LocalShellTool, ImageGenerationTool, CodeInterpreterTool, diff --git a/tests/extensions/memory/test_dapr_redis_integration.py b/tests/extensions/memory/test_dapr_redis_integration.py index 858ef1801..58d540c21 100644 --- a/tests/extensions/memory/test_dapr_redis_integration.py +++ b/tests/extensions/memory/test_dapr_redis_integration.py @@ -11,15 +11,32 @@ import asyncio import os +import shutil import tempfile import time import urllib.request +import docker # type: ignore[import-untyped] import pytest +from docker.errors import DockerException # type: ignore[import-untyped] # Skip tests if dependencies are not available pytest.importorskip("dapr") # Skip tests if Dapr is not installed pytest.importorskip("testcontainers") # Skip if testcontainers is not installed +if shutil.which("docker") is None: + pytest.skip( + "Docker executable is not available; skipping Dapr integration tests", + allow_module_level=True, + ) +try: + client = docker.from_env() + client.ping() +except DockerException: + pytest.skip( + "Docker daemon is not available; skipping Dapr integration tests", allow_module_level=True + ) +else: + client.close() from testcontainers.core.container import DockerContainer # type: ignore[import-untyped] from testcontainers.core.network import Network # type: ignore[import-untyped] diff --git a/tests/test_agents_logging.py b/tests/test_agents_logging.py new file mode 100644 index 000000000..c63fe3d0e --- /dev/null +++ b/tests/test_agents_logging.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import logging + +from agents import enable_verbose_stdout_logging + + +def test_enable_verbose_stdout_logging_attaches_handler() -> None: + logger = logging.getLogger("openai.agents") + logger.handlers.clear() + enable_verbose_stdout_logging() + assert logger.handlers + logger.handlers.clear() diff --git a/tests/test_apply_diff.py b/tests/test_apply_diff.py new file mode 100644 index 000000000..edb5be99a --- /dev/null +++ b/tests/test_apply_diff.py @@ -0,0 +1,36 @@ +"""Tests for the V4A diff helper.""" + +from __future__ import annotations + +import pytest + +from agents import apply_diff + + +def test_apply_diff_with_floating_hunk_adds_lines() -> None: + diff = "\n".join(["@@", "+hello", "+world"]) # no trailing newline + assert apply_diff("", diff) == "hello\nworld\n" + + +def test_apply_diff_create_mode_requires_plus_prefix() -> None: + diff = "plain line" + with pytest.raises(ValueError): + apply_diff("", diff, mode="create") + + +def test_apply_diff_create_mode_perserves_trailing_newline() -> None: + diff = "\n".join(["+hello", "+world", "+"]) + assert apply_diff("", diff, mode="create") == "hello\nworld\n" + + +def test_apply_diff_applies_contextual_replacement() -> None: + input_text = "line1\nline2\nline3\n" + diff = "\n".join(["@@ line1", "-line2", "+updated", " line3"]) + assert apply_diff(input_text, diff) == "line1\nupdated\nline3\n" + + +def test_apply_diff_raises_on_context_mismatch() -> None: + input_text = "one\ntwo\n" + diff = "\n".join(["@@ -1,2 +1,2 @@", " x", "-two", "+2"]) + with pytest.raises(ValueError): + apply_diff(input_text, diff) diff --git a/tests/test_apply_diff_helpers.py b/tests/test_apply_diff_helpers.py new file mode 100644 index 000000000..12141f42b --- /dev/null +++ b/tests/test_apply_diff_helpers.py @@ -0,0 +1,73 @@ +"""Direct tests for the apply_diff helpers to exercise corner cases.""" + +from __future__ import annotations + +import pytest + +from agents.apply_diff import ( + Chunk, + ParserState, + _apply_chunks, + _find_context, + _find_context_core, + _is_done, + _normalize_diff_lines, + _read_section, + _read_str, +) + + +def test_normalize_diff_lines_drops_trailing_blank() -> None: + assert _normalize_diff_lines("a\nb\n") == ["a", "b"] + + +def test_is_done_true_when_index_out_of_range() -> None: + state = ParserState(lines=["line"], index=1) + assert _is_done(state, []) + + +def test_read_str_returns_empty_when_missing_prefix() -> None: + state = ParserState(lines=["value"], index=0) + assert _read_str(state, "nomatch") == "" + assert state.index == 0 + + +def test_read_section_returns_eof_flag() -> None: + result = _read_section(["*** End of File"], 0) + assert result.eof + + +def test_read_section_raises_on_invalid_marker() -> None: + with pytest.raises(ValueError): + _read_section(["*** Bad Marker"], 0) + + +def test_read_section_raises_when_empty_segment() -> None: + with pytest.raises(ValueError): + _read_section([], 0) + + +def test_find_context_eof_fallbacks() -> None: + match = _find_context(["one"], ["missing"], start=0, eof=True) + assert match.new_index == -1 + assert match.fuzz >= 10000 + + +def test_find_context_core_stripped_matches() -> None: + match = _find_context_core([" line "], ["line"], start=0) + assert match.new_index == 0 + assert match.fuzz == 100 + + +def test_apply_chunks_rejects_bad_chunks() -> None: + with pytest.raises(ValueError): + _apply_chunks("abc", [Chunk(orig_index=10, del_lines=[], ins_lines=[])]) + + with pytest.raises(ValueError): + _apply_chunks( + "abc", + [ + Chunk(orig_index=0, del_lines=["a"], ins_lines=[]), + Chunk(orig_index=0, del_lines=["b"], ins_lines=[]), + ], + ) diff --git a/tests/test_apply_patch_tool.py b/tests/test_apply_patch_tool.py new file mode 100644 index 000000000..197a7550f --- /dev/null +++ b/tests/test_apply_patch_tool.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, cast + +import pytest + +from agents import Agent, ApplyPatchTool, RunConfig, RunContextWrapper, RunHooks +from agents._run_impl import ApplyPatchAction, ToolRunApplyPatchCall +from agents.editor import ApplyPatchOperation, ApplyPatchResult +from agents.items import ToolCallOutputItem + + +@dataclass +class DummyApplyPatchCall: + type: str + call_id: str + operation: dict[str, Any] + + +class RecordingEditor: + def __init__(self) -> None: + self.operations: list[ApplyPatchOperation] = [] + + def create_file(self, operation: ApplyPatchOperation) -> ApplyPatchResult: + self.operations.append(operation) + return ApplyPatchResult(output=f"Created {operation.path}") + + def update_file(self, operation: ApplyPatchOperation) -> ApplyPatchResult: + self.operations.append(operation) + return ApplyPatchResult(status="completed", output=f"Updated {operation.path}") + + def delete_file(self, operation: ApplyPatchOperation) -> ApplyPatchResult: + self.operations.append(operation) + return ApplyPatchResult(output=f"Deleted {operation.path}") + + +@pytest.mark.asyncio +async def test_apply_patch_tool_success() -> None: + editor = RecordingEditor() + tool = ApplyPatchTool(editor=editor) + tool_call = DummyApplyPatchCall( + type="apply_patch_call", + call_id="call_apply", + operation={"type": "update_file", "path": "tasks.md", "diff": "-a\n+b\n"}, + ) + tool_run = ToolRunApplyPatchCall(tool_call=tool_call, apply_patch_tool=tool) + agent = Agent(name="patcher", tools=[tool]) + context_wrapper: RunContextWrapper[Any] = RunContextWrapper(context=None) + + result = await ApplyPatchAction.execute( + agent=agent, + call=tool_run, + hooks=RunHooks[Any](), + context_wrapper=context_wrapper, + config=RunConfig(), + ) + + assert isinstance(result, ToolCallOutputItem) + assert "Updated tasks.md" in result.output + raw_item = cast(dict[str, Any], result.raw_item) + assert raw_item["type"] == "apply_patch_call_output" + assert raw_item["status"] == "completed" + assert raw_item["call_id"] == "call_apply" + assert editor.operations[0].type == "update_file" + assert isinstance(raw_item["output"], str) + assert raw_item["output"].startswith("Updated tasks.md") + input_payload = result.to_input_item() + assert isinstance(input_payload, dict) + payload_dict = cast(dict[str, Any], input_payload) + assert payload_dict["type"] == "apply_patch_call_output" + assert payload_dict["status"] == "completed" + + +@pytest.mark.asyncio +async def test_apply_patch_tool_failure() -> None: + class ExplodingEditor(RecordingEditor): + def update_file(self, operation: ApplyPatchOperation) -> ApplyPatchResult: + raise RuntimeError("boom") + + tool = ApplyPatchTool(editor=ExplodingEditor()) + tool_call = DummyApplyPatchCall( + type="apply_patch_call", + call_id="call_apply_fail", + operation={"type": "update_file", "path": "tasks.md", "diff": "-a\n+b\n"}, + ) + tool_run = ToolRunApplyPatchCall(tool_call=tool_call, apply_patch_tool=tool) + agent = Agent(name="patcher", tools=[tool]) + context_wrapper: RunContextWrapper[Any] = RunContextWrapper(context=None) + + result = await ApplyPatchAction.execute( + agent=agent, + call=tool_run, + hooks=RunHooks[Any](), + context_wrapper=context_wrapper, + config=RunConfig(), + ) + + assert isinstance(result, ToolCallOutputItem) + assert "boom" in result.output + raw_item = cast(dict[str, Any], result.raw_item) + assert raw_item["status"] == "failed" + assert isinstance(raw_item.get("output"), str) + input_payload = result.to_input_item() + assert isinstance(input_payload, dict) + payload_dict = cast(dict[str, Any], input_payload) + assert payload_dict["type"] == "apply_patch_call_output" + assert payload_dict["status"] == "failed" + + +@pytest.mark.asyncio +async def test_apply_patch_tool_accepts_mapping_call() -> None: + editor = RecordingEditor() + tool = ApplyPatchTool(editor=editor) + tool_call: dict[str, Any] = { + "type": "apply_patch_call", + "call_id": "call_mapping", + "operation": { + "type": "create_file", + "path": "notes.md", + "diff": "+hello\n", + }, + } + tool_run = ToolRunApplyPatchCall(tool_call=tool_call, apply_patch_tool=tool) + agent = Agent(name="patcher", tools=[tool]) + context_wrapper: RunContextWrapper[Any] = RunContextWrapper(context=None) + + result = await ApplyPatchAction.execute( + agent=agent, + call=tool_run, + hooks=RunHooks[Any](), + context_wrapper=context_wrapper, + config=RunConfig(), + ) + + assert isinstance(result, ToolCallOutputItem) + raw_item = cast(dict[str, Any], result.raw_item) + assert raw_item["call_id"] == "call_mapping" + assert editor.operations[0].path == "notes.md" diff --git a/tests/test_computer_action.py b/tests/test_computer_action.py index a306b1841..53f3aa9d9 100644 --- a/tests/test_computer_action.py +++ b/tests/test_computer_action.py @@ -4,7 +4,7 @@ that screenshots are taken and wrapped appropriately, and that the execute function invokes hooks and returns the expected ToolCallOutputItem.""" -from typing import Any +from typing import Any, cast import pytest from openai.types.responses.response_computer_tool_call import ( @@ -304,9 +304,8 @@ async def test_execute_invokes_hooks_and_returns_tool_call_output() -> None: assert output_item.agent is agent assert isinstance(output_item, ToolCallOutputItem) assert output_item.output == "data:image/png;base64,xyz" - raw = output_item.raw_item + raw = cast(dict[str, Any], output_item.raw_item) # Raw item is a dict-like mapping with expected output fields. - assert isinstance(raw, dict) assert raw["type"] == "computer_call_output" assert raw["output"]["type"] == "computer_screenshot" assert "image_url" in raw["output"] diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py index 9f227aadb..18107773d 100644 --- a/tests/test_function_tool.py +++ b/tests/test_function_tool.py @@ -344,3 +344,18 @@ async def third_tool_on_invoke_tool(ctx: RunContextWrapper[Any], args: str) -> s assert len(tools_with_ctx) == 2 assert tools_with_ctx[0].name == "another_tool" assert tools_with_ctx[1].name == "third_tool" + + +@pytest.mark.asyncio +async def test_async_failure_error_function_is_awaited() -> None: + async def failure_handler(ctx: RunContextWrapper[Any], exc: Exception) -> str: + return f"handled:{exc}" + + @function_tool(failure_error_function=lambda ctx, exc: failure_handler(ctx, exc)) + def boom() -> None: + """Always raises to trigger the failure handler.""" + raise RuntimeError("kapow") + + ctx = ToolContext(None, tool_name=boom.name, tool_call_id="boom", tool_arguments="{}") + result = await boom.on_invoke_tool(ctx, "{}") + assert result.startswith("handled:") diff --git a/tests/test_run_step_execution.py b/tests/test_run_step_execution.py index 4cf9ae832..49601bdab 100644 --- a/tests/test_run_step_execution.py +++ b/tests/test_run_step_execution.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import Any +from typing import Any, cast import pytest from pydantic import BaseModel @@ -303,15 +303,18 @@ def assert_item_is_function_tool_call( item: RunItem, name: str, arguments: str | None = None ) -> None: assert isinstance(item, ToolCallItem) - assert item.raw_item.type == "function_call" - assert item.raw_item.name == name - assert not arguments or item.raw_item.arguments == arguments + raw_item = getattr(item, "raw_item", None) + assert getattr(raw_item, "type", None) == "function_call" + assert getattr(raw_item, "name", None) == name + if arguments: + assert getattr(raw_item, "arguments", None) == arguments def assert_item_is_function_tool_call_output(item: RunItem, output: str) -> None: assert isinstance(item, ToolCallOutputItem) - assert item.raw_item["type"] == "function_call_output" - assert item.raw_item["output"] == output + raw_item = cast(dict[str, Any], item.raw_item) + assert raw_item["type"] == "function_call_output" + assert raw_item["output"] == output async def get_execute_result( diff --git a/tests/test_shell_call_serialization.py b/tests/test_shell_call_serialization.py new file mode 100644 index 000000000..8a592954b --- /dev/null +++ b/tests/test_shell_call_serialization.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import pytest + +from agents import _run_impl as run_impl +from agents.exceptions import ModelBehaviorError +from agents.tool import ShellCallOutcome, ShellCommandOutput + + +def test_coerce_shell_call_reads_max_output_length() -> None: + tool_call = { + "call_id": "shell-1", + "action": { + "commands": ["ls"], + "maxOutputLength": 512, + }, + "status": "in_progress", + } + result = run_impl._coerce_shell_call(tool_call) + assert result.action.max_output_length == 512 + + +def test_coerce_shell_call_requires_commands() -> None: + tool_call = {"call_id": "shell-2", "action": {"commands": []}} + with pytest.raises(ModelBehaviorError): + run_impl._coerce_shell_call(tool_call) + + +def test_normalize_shell_output_handles_timeout() -> None: + entry = { + "stdout": "", + "stderr": "", + "outcome": {"type": "timeout"}, + "provider_data": {"truncated": True}, + } + normalized = run_impl._normalize_shell_output(entry) + assert normalized.status == "timeout" + assert normalized.provider_data == {"truncated": True} + + +def test_normalize_shell_output_converts_string_outcome() -> None: + entry = { + "stdout": "hi", + "stderr": "", + "status": "completed", + "outcome": "success", + "exit_code": 0, + } + normalized = run_impl._normalize_shell_output(entry) + assert normalized.status == "completed" + assert normalized.exit_code in (None, 0) + + +def test_serialize_shell_output_emits_canonical_outcome() -> None: + output = ShellCommandOutput( + stdout="hello", + stderr="", + outcome=ShellCallOutcome(type="exit", exit_code=0), + ) + payload = run_impl._serialize_shell_output(output) + assert payload["outcome"]["type"] == "exit" + assert payload["outcome"]["exit_code"] == 0 + assert "exitCode" not in payload["outcome"] diff --git a/tests/test_shell_tool.py b/tests/test_shell_tool.py new file mode 100644 index 000000000..d2132d6a2 --- /dev/null +++ b/tests/test_shell_tool.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +from typing import Any, cast + +import pytest + +from agents import ( + Agent, + RunConfig, + RunContextWrapper, + RunHooks, + ShellCallOutcome, + ShellCommandOutput, + ShellResult, + ShellTool, +) +from agents._run_impl import ShellAction, ToolRunShellCall +from agents.items import ToolCallOutputItem + + +@pytest.mark.asyncio +async def test_shell_tool_structured_output_is_rendered() -> None: + shell_tool = ShellTool( + executor=lambda request: ShellResult( + output=[ + ShellCommandOutput( + command="echo hi", + stdout="hi\n", + outcome=ShellCallOutcome(type="exit", exit_code=0), + ), + ShellCommandOutput( + command="ls", + stdout="README.md\nsrc\n", + stderr="warning", + outcome=ShellCallOutcome(type="exit", exit_code=1), + ), + ], + provider_data={"runner": "demo"}, + max_output_length=4096, + ) + ) + + tool_call = { + "type": "shell_call", + "id": "shell_call", + "call_id": "call_shell", + "status": "completed", + "action": { + "commands": ["echo hi", "ls"], + "timeout_ms": 1000, + "max_output_length": 4096, + }, + } + + tool_run = ToolRunShellCall(tool_call=tool_call, shell_tool=shell_tool) + agent = Agent(name="shell-agent", tools=[shell_tool]) + context_wrapper: RunContextWrapper[Any] = RunContextWrapper(context=None) + + result = await ShellAction.execute( + agent=agent, + call=tool_run, + hooks=RunHooks[Any](), + context_wrapper=context_wrapper, + config=RunConfig(), + ) + + assert isinstance(result, ToolCallOutputItem) + assert "$ echo hi" in result.output + assert "stderr:\nwarning" in result.output + + raw_item = cast(dict[str, Any], result.raw_item) + assert raw_item["type"] == "shell_call_output" + assert raw_item["status"] == "completed" + assert raw_item["provider_data"]["runner"] == "demo" + assert raw_item["max_output_length"] == 4096 + shell_output = raw_item["shell_output"] + assert shell_output[1]["exit_code"] == 1 + assert isinstance(raw_item["output"], list) + first_output = raw_item["output"][0] + assert first_output["stdout"].startswith("hi") + assert first_output["outcome"]["type"] == "exit" + assert first_output["outcome"]["exit_code"] == 0 + assert "command" not in first_output + input_payload = result.to_input_item() + assert isinstance(input_payload, dict) + payload_dict = cast(dict[str, Any], input_payload) + assert payload_dict["type"] == "shell_call_output" + assert "status" not in payload_dict + assert "shell_output" not in payload_dict + assert "provider_data" not in payload_dict + + +@pytest.mark.asyncio +async def test_shell_tool_executor_failure_returns_error() -> None: + class ExplodingExecutor: + def __call__(self, request): + raise RuntimeError("boom") + + shell_tool = ShellTool(executor=ExplodingExecutor()) + tool_call = { + "type": "shell_call", + "id": "shell_call_fail", + "call_id": "call_shell_fail", + "status": "completed", + "action": {"commands": ["echo boom"], "timeout_ms": 1000}, + } + tool_run = ToolRunShellCall(tool_call=tool_call, shell_tool=shell_tool) + agent = Agent(name="shell-agent", tools=[shell_tool]) + context_wrapper: RunContextWrapper[Any] = RunContextWrapper(context=None) + + result = await ShellAction.execute( + agent=agent, + call=tool_run, + hooks=RunHooks[Any](), + context_wrapper=context_wrapper, + config=RunConfig(), + ) + + assert isinstance(result, ToolCallOutputItem) + assert "boom" in result.output + raw_item = cast(dict[str, Any], result.raw_item) + assert raw_item["type"] == "shell_call_output" + assert raw_item["status"] == "failed" + assert isinstance(raw_item["output"], list) + assert "boom" in raw_item["output"][0]["stdout"] + first_output = raw_item["output"][0] + assert first_output["outcome"]["type"] == "exit" + assert first_output["outcome"]["exit_code"] == 1 + assert "command" not in first_output + assert isinstance(raw_item["output"], list) + input_payload = result.to_input_item() + assert isinstance(input_payload, dict) + payload_dict = cast(dict[str, Any], input_payload) + assert payload_dict["type"] == "shell_call_output" + assert "status" not in payload_dict + assert "shell_output" not in payload_dict + assert "provider_data" not in payload_dict diff --git a/tests/test_tool_metadata.py b/tests/test_tool_metadata.py new file mode 100644 index 000000000..ad6395e9b --- /dev/null +++ b/tests/test_tool_metadata.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing import cast + +from openai.types.responses.tool_param import CodeInterpreter, ImageGeneration, Mcp + +from agents.computer import Computer +from agents.run_context import RunContextWrapper +from agents.tool import ( + ApplyPatchTool, + CodeInterpreterTool, + ComputerTool, + FileSearchTool, + HostedMCPTool, + ImageGenerationTool, + LocalShellTool, + ShellCallOutcome, + ShellCommandOutput, + ShellTool, + WebSearchTool, +) +from agents.tool_context import ToolContext + + +class DummyEditor: + def create_file(self, operation): + return None + + def update_file(self, operation): + return None + + def delete_file(self, operation): + return None + + +def test_tool_name_properties() -> None: + dummy_computer = cast(Computer, object()) + dummy_mcp = cast(Mcp, {"type": "mcp", "server_label": "demo"}) + dummy_code = cast(CodeInterpreter, {"type": "code_interpreter", "container": "python"}) + dummy_image = cast(ImageGeneration, {"type": "image_generation", "model": "gpt-image-1"}) + + assert FileSearchTool(vector_store_ids=[]).name == "file_search" + assert WebSearchTool().name == "web_search" + assert isinstance(ComputerTool(computer=dummy_computer).name, str) + assert HostedMCPTool(tool_config=dummy_mcp).name == "hosted_mcp" + assert CodeInterpreterTool(tool_config=dummy_code).name == "code_interpreter" + assert ImageGenerationTool(tool_config=dummy_image).name == "image_generation" + assert LocalShellTool(executor=lambda req: "ok").name == "local_shell" + assert ShellTool(executor=lambda req: "ok").type == "shell" + assert ApplyPatchTool(editor=DummyEditor()).type == "apply_patch" + + +def test_shell_command_output_status_property() -> None: + output = ShellCommandOutput(outcome=ShellCallOutcome(type="timeout")) + assert output.status == "timeout" + + +def test_tool_context_from_agent_context() -> None: + ctx = RunContextWrapper(context={"foo": "bar"}) + tool_call = ToolContext.from_agent_context( + ctx, + tool_call_id="123", + tool_call=type( + "Call", + (), + { + "name": "demo", + "arguments": "{}", + }, + )(), + ) + assert tool_call.tool_name == "demo"