From 043426487f6adab1567eb14b8761c554204f66ea Mon Sep 17 00:00:00 2001 From: Steven C Date: Tue, 23 Sep 2025 11:29:12 -0400 Subject: [PATCH 1/4] Adding tool input and output guardrails --- docs/guardrails.md | 80 +++++ docs/ref/tool_guardrails.md | 3 + examples/basic/tool_guardrails.py | 145 +++++++++ examples/basic/tools.py | 1 + mkdocs.yml | 1 + src/agents/__init__.py | 20 ++ src/agents/_run_impl.py | 88 ++++-- src/agents/exceptions.py | 35 +++ src/agents/tool.py | 8 + src/agents/tool_guardrails.py | 204 +++++++++++++ tests/test_tool_guardrails.py | 482 ++++++++++++++++++++++++++++++ 11 files changed, 1049 insertions(+), 18 deletions(-) create mode 100644 docs/ref/tool_guardrails.md create mode 100644 examples/basic/tool_guardrails.py create mode 100644 src/agents/tool_guardrails.py create mode 100644 tests/test_tool_guardrails.py diff --git a/docs/guardrails.md b/docs/guardrails.md index 8df904a4c..f73f9d4fa 100644 --- a/docs/guardrails.md +++ b/docs/guardrails.md @@ -152,3 +152,83 @@ async def main(): 2. This is the guardrail's output type. 3. This is the guardrail function that receives the agent's output, and returns the result. 4. This is the actual agent that defines the workflow. + +## Tool guardrails + +Tool guardrails provide fine-grained control over individual tool calls, allowing you to validate inputs and outputs at the tool level. This is particularly useful for: + +- Blocking sensitive data in tool arguments +- Preventing unauthorized access to certain tools +- Sanitizing tool outputs before they're returned +- Implementing custom validation logic for specific tools + +There are two types of tool guardrails: + +1. **Tool input guardrails** run before a tool is executed, validating the tool call arguments +2. **Tool output guardrails** run after a tool is executed, validating the tool's output + +### Tool input guardrails + +Tool input guardrails run in 3 steps: + +1. First, the guardrail receives the tool call data including arguments, context, and agent information +2. Next, the guardrail function runs to produce a [`ToolGuardrailFunctionOutput`][agents.tool_guardrails.ToolGuardrailFunctionOutput] +3. Finally, we check if [`.tripwire_triggered`][agents.tool_guardrails.ToolGuardrailFunctionOutput.tripwire_triggered] is true. If true, a [`ToolInputGuardrailTripwireTriggered`][agents.exceptions.ToolInputGuardrailTripwireTriggered] exception is raised + +### Tool output guardrails + +Tool output guardrails run in 3 steps: + +1. First, the guardrail receives the tool call data plus the tool's output +2. Next, the guardrail function runs to produce a [`ToolGuardrailFunctionOutput`][agents.tool_guardrails.ToolGuardrailFunctionOutput] +3. Finally, we check if [`.tripwire_triggered`][agents.tool_guardrails.ToolGuardrailFunctionOutput.tripwire_triggered] is true. If true, a [`ToolOutputGuardrailTripwireTriggered`][agents.exceptions.ToolOutputGuardrailTripwireTriggered] exception is raised + +### Implementing tool guardrails + +You can create tool guardrails using the `@tool_input_guardrail` and `@tool_output_guardrail` decorators: + +```python +from agents import ( + ToolGuardrailFunctionOutput, + ToolInputGuardrailData, + ToolOutputGuardrailData, + tool_input_guardrail, + tool_output_guardrail, +) + +@tool_input_guardrail +def block_sensitive_words(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + """Block tool calls that contain sensitive words in arguments.""" + # Check arguments for sensitive content + if "password" in data.tool_call.arguments.lower(): + return ToolGuardrailFunctionOutput( + tripwire_triggered=True, + model_message="🚨 Tool call blocked: contains sensitive word", + output_info={"blocked_word": "password"}, + ) + return ToolGuardrailFunctionOutput(tripwire_triggered=False, output_info="Input validated") + +@tool_output_guardrail +def block_sensitive_output(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + """Block tool outputs that contain sensitive data.""" + if "ssn" in str(data.output).lower(): + return ToolGuardrailFunctionOutput( + tripwire_triggered=True, + model_message="🚨 Tool output blocked: contains sensitive data", + output_info={"blocked_pattern": "SSN"}, + ) + return ToolGuardrailFunctionOutput(tripwire_triggered=False, output_info="Output validated") + +# Apply guardrails to tools +my_tool.tool_input_guardrails = [block_sensitive_words] +my_tool.tool_output_guardrails = [block_sensitive_output] +``` + +For a complete working example, see [tool_guardrails.py](https://github.com/openai/openai-agents-python/blob/main/examples/basic/tool_guardrails.py). + +### Key differences from agent guardrails + +- **Scope**: Tool guardrails operate on individual tool calls, while agent guardrails operate on the entire agent input/output +- **Timing**: Tool guardrails run immediately before/after tool execution, while agent guardrails run at the beginning/end of agent execution +- **Data access**: Tool guardrails have access to the specific tool call arguments and outputs, plus the tool context +- **Application**: Tool guardrails are applied directly to function tools via the `tool_input_guardrails` and `tool_output_guardrails` attributes diff --git a/docs/ref/tool_guardrails.md b/docs/ref/tool_guardrails.md new file mode 100644 index 000000000..bc3639304 --- /dev/null +++ b/docs/ref/tool_guardrails.md @@ -0,0 +1,3 @@ +# `Tool Guardrails` + +::: agents.tool_guardrails diff --git a/examples/basic/tool_guardrails.py b/examples/basic/tool_guardrails.py new file mode 100644 index 000000000..c85d1c6eb --- /dev/null +++ b/examples/basic/tool_guardrails.py @@ -0,0 +1,145 @@ +import asyncio +import json + +from agents import ( + Agent, + Runner, + ToolGuardrailFunctionOutput, + ToolInputGuardrailData, + ToolInputGuardrailTripwireTriggered, + ToolOutputGuardrailData, + ToolOutputGuardrailTripwireTriggered, + function_tool, + tool_input_guardrail, + tool_output_guardrail, +) + + +@function_tool +def send_email(to: str, subject: str, body: str) -> str: + """Send an email to the specified recipient.""" + return f"Email sent to {to} with subject '{subject}'" + + +@function_tool +def get_user_data(user_id: str) -> dict[str, str]: + """Get user data by ID.""" + # Simulate returning sensitive data + return { + "user_id": user_id, + "name": "John Doe", + "email": "john@example.com", + "ssn": "123-45-6789", # Sensitive data that should be blocked! + "phone": "555-1234", + } + + +@tool_input_guardrail +def block_sensitive_words(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + """Block tool calls that contain sensitive words in arguments.""" + try: + args = json.loads(data.tool_call.arguments) + except json.JSONDecodeError: + return ToolGuardrailFunctionOutput( + tripwire_triggered=False, output_info="Invalid JSON arguments" + ) + + # Check for suspicious content + sensitive_words = [ + "password", + "hack", + "exploit", + "malware", + "orange", + ] # to mock sensitive words + for key, value in args.items(): + value_str = str(value).lower() + for word in sensitive_words: + if word in value_str: + return ToolGuardrailFunctionOutput( + tripwire_triggered=True, + model_message=f"🚨 Tool call blocked: contains '{word}'", + output_info={"blocked_word": word, "argument": key}, + ) + + return ToolGuardrailFunctionOutput(tripwire_triggered=False, output_info="Input validated") + + +@tool_output_guardrail +def block_sensitive_output(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + """Block tool outputs that contain sensitive data.""" + output_str = str(data.output).lower() + + # Check for sensitive data patterns + if "ssn" in output_str or "123-45-6789" in output_str: + return ToolGuardrailFunctionOutput( + tripwire_triggered=True, + model_message="🚨 Tool output blocked: contains sensitive data", + output_info={"blocked_pattern": "SSN", "tool": data.tool_call.name}, + ) + + return ToolGuardrailFunctionOutput(tripwire_triggered=False, output_info="Output validated") + + +# Apply guardrails to tools +send_email.tool_input_guardrails = [block_sensitive_words] +get_user_data.tool_output_guardrails = [block_sensitive_output] + +agent = Agent( + name="Secure Assistant", + instructions="You are a helpful assistant with access to email and user data tools.", + tools=[send_email, get_user_data], +) + + +async def main(): + print("=== Tool Guardrails Example ===\n") + + # Example 1: Normal operation - should work fine + print("1. Normal email sending:") + try: + result = await Runner.run(agent, "Send a welcome email to john@example.com") + print(f"✅ Success: {result.final_output}\n") + except Exception as e: + print(f"❌ Error: {e}\n") + + # Example 2: Input guardrail triggers - should block suspicious content + print("2. Attempting to send email with suspicious content:") + try: + result = await Runner.run( + agent, "Send an email to john@example.com with the subject: orange" + ) + print(f"✅ Success: {result.final_output}\n") + except ToolInputGuardrailTripwireTriggered as e: + print(f"🚨 Input guardrail triggered: {e.output.model_message}") + print(f" Details: {e.output.output_info}\n") + + # Example 3: Output guardrail triggers - should block sensitive data + print("3. Attempting to get user data (contains SSN):") + try: + result = await Runner.run(agent, "Get the data for user ID user123") + print(f"✅ Success: {result.final_output}\n") + except ToolOutputGuardrailTripwireTriggered as e: + print(f"🚨 Output guardrail triggered: {e.output.model_message}") + print(f" Details: {e.output.output_info}\n") + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Example output: + +=== Tool Guardrails Example === + +1. Normal email sending: +✅ Success: I've sent a welcome email to john@example.com with an appropriate subject and greeting message. + +2. Attempting to send email with suspicious content: +🚨 Input guardrail triggered: 🚨 Tool call blocked: contains 'orange' + Details: {'blocked_word': 'orange', 'argument': 'subject'} + +3. Attempting to get user data (contains SSN): +🚨 Output guardrail triggered: 🚨 Tool output blocked: contains sensitive data + Details: {'blocked_pattern': 'SSN', 'tool': 'get_user_data'} +""" diff --git a/examples/basic/tools.py b/examples/basic/tools.py index 1c4496603..2052d9427 100644 --- a/examples/basic/tools.py +++ b/examples/basic/tools.py @@ -18,6 +18,7 @@ def get_weather(city: Annotated[str, "The city to get the weather for"]) -> Weat print("[debug] get_weather called") return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.") + agent = Agent( name="Hello world", instructions="You are a helpful agent.", diff --git a/mkdocs.yml b/mkdocs.yml index bea747bed..c4096d6ab 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -101,6 +101,7 @@ plugins: - ref/usage.md - ref/exceptions.md - ref/guardrail.md + - ref/tool_guardrails.md - ref/model_settings.md - ref/agent_output.md - ref/function_schema.md diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 3a8260f29..1c85950db 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -21,6 +21,8 @@ ModelBehaviorError, OutputGuardrailTripwireTriggered, RunErrorDetails, + ToolInputGuardrailTripwireTriggered, + ToolOutputGuardrailTripwireTriggered, UserError, ) from .guardrail import ( @@ -83,6 +85,15 @@ default_tool_error_function, function_tool, ) +from .tool_guardrails import ( + ToolGuardrailFunctionOutput, + ToolInputGuardrail, + ToolInputGuardrailData, + ToolOutputGuardrail, + ToolOutputGuardrailData, + tool_input_guardrail, + tool_output_guardrail, +) from .tracing import ( AgentSpanData, CustomSpanData, @@ -191,6 +202,8 @@ def enable_verbose_stdout_logging(): "AgentsException", "InputGuardrailTripwireTriggered", "OutputGuardrailTripwireTriggered", + "ToolInputGuardrailTripwireTriggered", + "ToolOutputGuardrailTripwireTriggered", "DynamicPromptFunction", "GenerateDynamicPromptData", "Prompt", @@ -204,6 +217,13 @@ def enable_verbose_stdout_logging(): "GuardrailFunctionOutput", "input_guardrail", "output_guardrail", + "ToolInputGuardrail", + "ToolOutputGuardrail", + "ToolGuardrailFunctionOutput", + "ToolInputGuardrailData", + "ToolOutputGuardrailData", + "tool_input_guardrail", + "tool_output_guardrail", "handoff", "Handoff", "HandoffInputData", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index a2d872bf1..342cd569b 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -44,7 +44,13 @@ from .agent import Agent, ToolsToFinalOutputResult from .agent_output import AgentOutputSchemaBase from .computer import AsyncComputer, Computer -from .exceptions import AgentsException, ModelBehaviorError, UserError +from .exceptions import ( + AgentsException, + ModelBehaviorError, + ToolInputGuardrailTripwireTriggered, + ToolOutputGuardrailTripwireTriggered, + UserError, +) from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult from .handoffs import Handoff, HandoffInputData from .items import ( @@ -80,6 +86,10 @@ Tool, ) from .tool_context import ToolContext +from .tool_guardrails import ( + ToolInputGuardrailData, + ToolOutputGuardrailData, +) from .tracing import ( SpanError, Trace, @@ -569,24 +579,66 @@ async def run_single_tool( if config.trace_include_sensitive_data: span_fn.span_data.input = tool_call.arguments try: - _, _, result = await asyncio.gather( - hooks.on_tool_start(tool_context, agent, func_tool), - ( - agent.hooks.on_tool_start(tool_context, agent, func_tool) - if agent.hooks - else _coro.noop_coroutine() - ), - func_tool.on_invoke_tool(tool_context, tool_call.arguments), - ) + # 1) Run input tool guardrails, if any + final_result: Any | None = None + if func_tool.tool_input_guardrails: + for guardrail in func_tool.tool_input_guardrails: + gr_out = await guardrail.run( + ToolInputGuardrailData( + context=tool_context, + agent=agent, + tool_call=tool_call, + ) + ) + if gr_out.tripwire_triggered: + # Raise tripwire exception instead of just setting result + raise ToolInputGuardrailTripwireTriggered( + guardrail=guardrail, output=gr_out + ) + + if final_result is None: + # 2) Actually run the tool + await asyncio.gather( + hooks.on_tool_start(tool_context, agent, func_tool), + ( + agent.hooks.on_tool_start(tool_context, agent, func_tool) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + real_result = await func_tool.on_invoke_tool( + tool_context, tool_call.arguments + ) - await asyncio.gather( - hooks.on_tool_end(tool_context, agent, func_tool, result), - ( - agent.hooks.on_tool_end(tool_context, agent, func_tool, result) - if agent.hooks - else _coro.noop_coroutine() - ), - ) + # 3) Run output tool guardrails, if any + final_result = real_result + if func_tool.tool_output_guardrails: + for output_guardrail in func_tool.tool_output_guardrails: + gr_out = await output_guardrail.run( + ToolOutputGuardrailData( + context=tool_context, + agent=agent, + tool_call=tool_call, + output=real_result, + ) + ) + if gr_out.tripwire_triggered: + # Raise tripwire exception instead of just setting result + raise ToolOutputGuardrailTripwireTriggered( + guardrail=output_guardrail, output=gr_out + ) + # 4) Tool end hooks (with final result, which may have been overridden) + await asyncio.gather( + hooks.on_tool_end(tool_context, agent, func_tool, final_result), + ( + agent.hooks.on_tool_end( + tool_context, agent, func_tool, final_result + ) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + result = final_result except Exception as e: _error_tracing.attach_error_to_current_span( SpanError( diff --git a/src/agents/exceptions.py b/src/agents/exceptions.py index c00024c2e..39518c39d 100644 --- a/src/agents/exceptions.py +++ b/src/agents/exceptions.py @@ -8,6 +8,11 @@ from .guardrail import InputGuardrailResult, OutputGuardrailResult from .items import ModelResponse, RunItem, TResponseInputItem from .run_context import RunContextWrapper + from .tool_guardrails import ( + ToolGuardrailFunctionOutput, + ToolInputGuardrail, + ToolOutputGuardrail, + ) from .util._pretty_print import pretty_print_run_error_details @@ -94,3 +99,33 @@ def __init__(self, guardrail_result: OutputGuardrailResult): super().__init__( f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire" ) + + +class ToolInputGuardrailTripwireTriggered(AgentsException): + """Exception raised when a tool input guardrail tripwire is triggered.""" + + guardrail: ToolInputGuardrail[Any] + """The guardrail that was triggered.""" + + output: ToolGuardrailFunctionOutput + """The output from the guardrail function.""" + + def __init__(self, guardrail: ToolInputGuardrail[Any], output: ToolGuardrailFunctionOutput): + self.guardrail = guardrail + self.output = output + super().__init__(f"Tool input guardrail {guardrail.__class__.__name__} triggered tripwire") + + +class ToolOutputGuardrailTripwireTriggered(AgentsException): + """Exception raised when a tool output guardrail tripwire is triggered.""" + + guardrail: ToolOutputGuardrail[Any] + """The guardrail that was triggered.""" + + output: ToolGuardrailFunctionOutput + """The output from the guardrail function.""" + + def __init__(self, guardrail: ToolOutputGuardrail[Any], output: ToolGuardrailFunctionOutput): + self.guardrail = guardrail + self.output = output + super().__init__(f"Tool output guardrail {guardrail.__class__.__name__} triggered tripwire") diff --git a/src/agents/tool.py b/src/agents/tool.py index 7ba9435ed..2796754d5 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -27,6 +27,7 @@ from .run_context import RunContextWrapper from .strict_schema import ensure_strict_json_schema from .tool_context import ToolContext +from .tool_guardrails import ToolInputGuardrail, ToolOutputGuardrail from .tracing import SpanError from .util import _error_tracing from .util._types import MaybeAwaitable @@ -94,6 +95,13 @@ class FunctionTool: and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool based on your context/state.""" + # Tool-specific guardrails + tool_input_guardrails: list[ToolInputGuardrail[Any]] | None = None + """Optional list of input guardrails to run before invoking this tool.""" + + tool_output_guardrails: list[ToolOutputGuardrail[Any]] | None = None + """Optional list of output guardrails to run after invoking this tool.""" + def __post_init__(self): if self.strict_json_schema: self.params_json_schema = ensure_strict_json_schema(self.params_json_schema) diff --git a/src/agents/tool_guardrails.py b/src/agents/tool_guardrails.py new file mode 100644 index 000000000..93386c04e --- /dev/null +++ b/src/agents/tool_guardrails.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +import inspect +from collections.abc import Awaitable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Generic, overload + +from openai.types.responses import ResponseFunctionToolCall +from typing_extensions import TypeVar + +from .exceptions import UserError +from .tool_context import ToolContext +from .util._types import MaybeAwaitable + +if TYPE_CHECKING: + from .agent import Agent + + +@dataclass +class ToolGuardrailFunctionOutput: + """The output of a tool guardrail function.""" + + output_info: Any + """ + Optional data about checks performed. For example, the guardrail could include + information about the checks it performed and granular results. + """ + + tripwire_triggered: bool + """ + Whether the tripwire was triggered. If triggered, the tool execution will be halted. + """ + + model_message: str | None = None + """ + Message to send back to the model as the tool output if tripped. + """ + + +@dataclass +class ToolInputGuardrailData: + """Input data passed to a tool input guardrail function.""" + + context: ToolContext[Any] + """ + The tool context containing information about the current tool execution. + """ + + agent: Agent[Any] + """ + The agent that is executing the tool. + """ + + tool_call: ResponseFunctionToolCall + """ + The tool call data including the function name and arguments. + """ + + +@dataclass +class ToolOutputGuardrailData(ToolInputGuardrailData): + """Input data passed to a tool output guardrail function. + + Extends input data with the tool's output. + """ + + output: Any + """ + The output produced by the tool function. + """ + + +TContext_co = TypeVar("TContext_co", bound=Any, covariant=True) + + +@dataclass +class ToolInputGuardrail(Generic[TContext_co]): + """A guardrail that runs before a function tool is invoked.""" + + guardrail_function: Callable[ + [ToolInputGuardrailData], MaybeAwaitable[ToolGuardrailFunctionOutput] + ] + """ + The function that implements the guardrail logic. + """ + + name: str | None = None + """ + Optional name for the guardrail. If not provided, uses the function name. + """ + + def get_name(self) -> str: + return self.name or self.guardrail_function.__name__ + + async def run(self, data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + if not callable(self.guardrail_function): + raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}") + + result = self.guardrail_function(data) + if inspect.isawaitable(result): + return await result + return result + + +@dataclass +class ToolOutputGuardrail(Generic[TContext_co]): + """A guardrail that runs after a function tool is invoked.""" + + guardrail_function: Callable[ + [ToolOutputGuardrailData], MaybeAwaitable[ToolGuardrailFunctionOutput] + ] + """ + The function that implements the guardrail logic. + """ + + name: str | None = None + """ + Optional name for the guardrail. If not provided, uses the function name. + """ + + def get_name(self) -> str: + return self.name or self.guardrail_function.__name__ + + async def run(self, data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + if not callable(self.guardrail_function): + raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}") + + result = self.guardrail_function(data) + if inspect.isawaitable(result): + return await result + return result + + +# Decorators +_ToolInputFuncSync = Callable[[ToolInputGuardrailData], ToolGuardrailFunctionOutput] +_ToolInputFuncAsync = Callable[[ToolInputGuardrailData], Awaitable[ToolGuardrailFunctionOutput]] + + +@overload +def tool_input_guardrail(func: _ToolInputFuncSync): ... + + +@overload +def tool_input_guardrail(func: _ToolInputFuncAsync): ... + + +@overload +def tool_input_guardrail( + *, name: str | None = None +) -> Callable[[_ToolInputFuncSync | _ToolInputFuncAsync], ToolInputGuardrail[Any]]: ... + + +def tool_input_guardrail( + func: _ToolInputFuncSync | _ToolInputFuncAsync | None = None, + *, + name: str | None = None, +) -> ( + ToolInputGuardrail[Any] + | Callable[[_ToolInputFuncSync | _ToolInputFuncAsync], ToolInputGuardrail[Any]] +): + """Decorator to create a ToolInputGuardrail from a function.""" + + def decorator(f: _ToolInputFuncSync | _ToolInputFuncAsync) -> ToolInputGuardrail[Any]: + return ToolInputGuardrail(guardrail_function=f, name=name or f.__name__) + + if func is not None: + return decorator(func) + return decorator + + +_ToolOutputFuncSync = Callable[[ToolOutputGuardrailData], ToolGuardrailFunctionOutput] +_ToolOutputFuncAsync = Callable[[ToolOutputGuardrailData], Awaitable[ToolGuardrailFunctionOutput]] + + +@overload +def tool_output_guardrail(func: _ToolOutputFuncSync): ... + + +@overload +def tool_output_guardrail(func: _ToolOutputFuncAsync): ... + + +@overload +def tool_output_guardrail( + *, name: str | None = None +) -> Callable[[_ToolOutputFuncSync | _ToolOutputFuncAsync], ToolOutputGuardrail[Any]]: ... + + +def tool_output_guardrail( + func: _ToolOutputFuncSync | _ToolOutputFuncAsync | None = None, + *, + name: str | None = None, +) -> ( + ToolOutputGuardrail[Any] + | Callable[[_ToolOutputFuncSync | _ToolOutputFuncAsync], ToolOutputGuardrail[Any]] +): + """Decorator to create a ToolOutputGuardrail from a function.""" + + def decorator(f: _ToolOutputFuncSync | _ToolOutputFuncAsync) -> ToolOutputGuardrail[Any]: + return ToolOutputGuardrail(guardrail_function=f, name=name or f.__name__) + + if func is not None: + return decorator(func) + return decorator diff --git a/tests/test_tool_guardrails.py b/tests/test_tool_guardrails.py new file mode 100644 index 000000000..0b36b8abc --- /dev/null +++ b/tests/test_tool_guardrails.py @@ -0,0 +1,482 @@ +from __future__ import annotations + +import asyncio +from typing import Any + +import pytest +from openai.types.responses import ResponseFunctionToolCall + +from agents import ( + Agent, + ToolGuardrailFunctionOutput, + ToolInputGuardrail, + ToolInputGuardrailData, + ToolInputGuardrailTripwireTriggered, + ToolOutputGuardrail, + ToolOutputGuardrailData, + ToolOutputGuardrailTripwireTriggered, + UserError, +) +from agents.tool_context import ToolContext +from agents.tool_guardrails import tool_input_guardrail, tool_output_guardrail + + +def get_mock_tool_call(arguments: str = "{}") -> ResponseFunctionToolCall: + """Helper to create a mock tool call for testing.""" + return ResponseFunctionToolCall( + call_id="call_123", type="function_call", name="test_tool", arguments=arguments + ) + + +def get_mock_tool_context() -> ToolContext: + """Helper to create a mock tool context for testing.""" + return ToolContext(context=None, tool_name="test_tool", tool_call_id="call_123") + + +def get_sync_input_guardrail(triggers: bool, output_info: Any | None = None): + """Helper to create a sync input guardrail function.""" + + def sync_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput( + output_info=output_info, + tripwire_triggered=triggers, + ) + + return sync_guardrail + + +def get_async_input_guardrail(triggers: bool, output_info: Any | None = None): + """Helper to create an async input guardrail function.""" + + async def async_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput( + output_info=output_info, + tripwire_triggered=triggers, + ) + + return async_guardrail + + +def get_sync_output_guardrail(triggers: bool, output_info: Any | None = None): + """Helper to create a sync output guardrail function.""" + + def sync_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput( + output_info=output_info, + tripwire_triggered=triggers, + ) + + return sync_guardrail + + +def get_async_output_guardrail(triggers: bool, output_info: Any | None = None): + """Helper to create an async output guardrail function.""" + + async def async_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput( + output_info=output_info, + tripwire_triggered=triggers, + ) + + return async_guardrail + + +@pytest.mark.asyncio +async def test_sync_tool_input_guardrail(): + """Test sync tool input guardrail execution.""" + # Test non-triggering guardrail + guardrail: ToolInputGuardrail[Any] = ToolInputGuardrail( + guardrail_function=get_sync_input_guardrail(triggers=False) + ) + data = ToolInputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + tool_call=get_mock_tool_call(), + ) + result = await guardrail.run(data) + assert not result.tripwire_triggered + assert result.output_info is None + + # Test triggering guardrail + guardrail_2: ToolInputGuardrail[Any] = ToolInputGuardrail( + guardrail_function=get_sync_input_guardrail(triggers=True) + ) + result = await guardrail_2.run(data) + assert result.tripwire_triggered + assert result.output_info is None + + # Test triggering guardrail with output info + guardrail_3: ToolInputGuardrail[Any] = ToolInputGuardrail( + guardrail_function=get_sync_input_guardrail(triggers=True, output_info="test_info") + ) + result = await guardrail_3.run(data) + assert result.tripwire_triggered + assert result.output_info == "test_info" + + +@pytest.mark.asyncio +async def test_async_tool_input_guardrail(): + """Test async tool input guardrail execution.""" + # Test non-triggering guardrail + guardrail: ToolInputGuardrail[Any] = ToolInputGuardrail( + guardrail_function=get_async_input_guardrail(triggers=False) + ) + data = ToolInputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + tool_call=get_mock_tool_call(), + ) + result = await guardrail.run(data) + assert not result.tripwire_triggered + assert result.output_info is None + + # Test triggering guardrail + guardrail_2: ToolInputGuardrail[Any] = ToolInputGuardrail( + guardrail_function=get_async_input_guardrail(triggers=True) + ) + result = await guardrail_2.run(data) + assert result.tripwire_triggered + assert result.output_info is None + + # Test triggering guardrail with output info + guardrail_3: ToolInputGuardrail[Any] = ToolInputGuardrail( + guardrail_function=get_async_input_guardrail(triggers=True, output_info="test_info") + ) + result = await guardrail_3.run(data) + assert result.tripwire_triggered + assert result.output_info == "test_info" + + +@pytest.mark.asyncio +async def test_sync_tool_output_guardrail(): + """Test sync tool output guardrail execution.""" + # Test non-triggering guardrail + guardrail: ToolOutputGuardrail[Any] = ToolOutputGuardrail( + guardrail_function=get_sync_output_guardrail(triggers=False) + ) + data = ToolOutputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + tool_call=get_mock_tool_call(), + output="test output", + ) + result = await guardrail.run(data) + assert not result.tripwire_triggered + assert result.output_info is None + + # Test triggering guardrail + guardrail_2: ToolOutputGuardrail[Any] = ToolOutputGuardrail( + guardrail_function=get_sync_output_guardrail(triggers=True) + ) + result = await guardrail_2.run(data) + assert result.tripwire_triggered + assert result.output_info is None + + # Test triggering guardrail with output info + guardrail_3: ToolOutputGuardrail[Any] = ToolOutputGuardrail( + guardrail_function=get_sync_output_guardrail(triggers=True, output_info="test_info") + ) + result = await guardrail_3.run(data) + assert result.tripwire_triggered + assert result.output_info == "test_info" + + +@pytest.mark.asyncio +async def test_async_tool_output_guardrail(): + """Test async tool output guardrail execution.""" + # Test non-triggering guardrail + guardrail: ToolOutputGuardrail[Any] = ToolOutputGuardrail( + guardrail_function=get_async_output_guardrail(triggers=False) + ) + data = ToolOutputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + tool_call=get_mock_tool_call(), + output="test output", + ) + result = await guardrail.run(data) + assert not result.tripwire_triggered + assert result.output_info is None + + # Test triggering guardrail + guardrail_2: ToolOutputGuardrail[Any] = ToolOutputGuardrail( + guardrail_function=get_async_output_guardrail(triggers=True) + ) + result = await guardrail_2.run(data) + assert result.tripwire_triggered + assert result.output_info is None + + # Test triggering guardrail with output info + guardrail_3: ToolOutputGuardrail[Any] = ToolOutputGuardrail( + guardrail_function=get_async_output_guardrail(triggers=True, output_info="test_info") + ) + result = await guardrail_3.run(data) + assert result.tripwire_triggered + assert result.output_info == "test_info" + + +@pytest.mark.asyncio +async def test_invalid_tool_input_guardrail_raises_user_error(): + """Test that invalid guardrail functions raise UserError.""" + with pytest.raises(UserError): + # Purposely ignoring type error + guardrail: ToolInputGuardrail[Any] = ToolInputGuardrail(guardrail_function="foo") # type: ignore + data = ToolInputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + tool_call=get_mock_tool_call(), + ) + await guardrail.run(data) + + +@pytest.mark.asyncio +async def test_invalid_tool_output_guardrail_raises_user_error(): + """Test that invalid guardrail functions raise UserError.""" + with pytest.raises(UserError): + # Purposely ignoring type error + guardrail: ToolOutputGuardrail[Any] = ToolOutputGuardrail(guardrail_function="foo") # type: ignore + data = ToolOutputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + tool_call=get_mock_tool_call(), + output="test output", + ) + await guardrail.run(data) + + +# Test decorators + + +@tool_input_guardrail +def decorated_input_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput( + output_info="test_1", + tripwire_triggered=False, + ) + + +@tool_input_guardrail(name="Custom input name") +def decorated_named_input_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput( + output_info="test_2", + tripwire_triggered=False, + ) + + +@pytest.mark.asyncio +async def test_tool_input_guardrail_decorators(): + """Test input guardrail decorators.""" + data = ToolInputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + tool_call=get_mock_tool_call(), + ) + + # Test basic decorator + guardrail = decorated_input_guardrail + result = await guardrail.run(data) + assert not result.tripwire_triggered + assert result.output_info == "test_1" + + # Test named decorator + guardrail = decorated_named_input_guardrail + result = await guardrail.run(data) + assert not result.tripwire_triggered + assert result.output_info == "test_2" + assert guardrail.get_name() == "Custom input name" + + +@tool_output_guardrail +def decorated_output_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput( + output_info="test_3", + tripwire_triggered=False, + ) + + +@tool_output_guardrail(name="Custom output name") +def decorated_named_output_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput( + output_info="test_4", + tripwire_triggered=False, + ) + + +@pytest.mark.asyncio +async def test_tool_output_guardrail_decorators(): + """Test output guardrail decorators.""" + data = ToolOutputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + tool_call=get_mock_tool_call(), + output="test output", + ) + + # Test basic decorator + guardrail = decorated_output_guardrail + result = await guardrail.run(data) + assert not result.tripwire_triggered + assert result.output_info == "test_3" + + # Test named decorator + guardrail = decorated_named_output_guardrail + result = await guardrail.run(data) + assert not result.tripwire_triggered + assert result.output_info == "test_4" + assert guardrail.get_name() == "Custom output name" + + +# Test practical examples + + +@pytest.mark.asyncio +async def test_password_blocking_input_guardrail(): + """Test a realistic input guardrail that blocks passwords.""" + + @tool_input_guardrail + def check_for_password(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + if "password" in data.tool_call.arguments.lower(): + return ToolGuardrailFunctionOutput( + output_info={"blocked_word": "password"}, + tripwire_triggered=True, + model_message="Tool call blocked: contains password", + ) + return ToolGuardrailFunctionOutput( + output_info="safe_input", + tripwire_triggered=False, + ) + + # Test with password - should trigger + data = ToolInputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + tool_call=get_mock_tool_call('{"message": "Hello password world"}'), + ) + result = await check_for_password.run(data) + assert result.tripwire_triggered is True + assert result.model_message == "Tool call blocked: contains password" + assert result.output_info["blocked_word"] == "password" + + # Test without password - should pass + data = ToolInputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + tool_call=get_mock_tool_call('{"message": "Hello safe world"}'), + ) + result = await check_for_password.run(data) + assert result.tripwire_triggered is False + assert result.output_info == "safe_input" + + +@pytest.mark.asyncio +async def test_ssn_blocking_output_guardrail(): + """Test a realistic output guardrail that blocks SSNs.""" + + @tool_output_guardrail + def check_for_ssn(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + output_str = str(data.output).lower() + if "ssn" in output_str or "123-45-6789" in output_str: + return ToolGuardrailFunctionOutput( + output_info={"blocked_pattern": "SSN"}, + tripwire_triggered=True, + model_message="Output blocked: contains SSN", + ) + return ToolGuardrailFunctionOutput( + output_info="safe_output", + tripwire_triggered=False, + ) + + # Test with SSN in output - should trigger + data = ToolOutputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + tool_call=get_mock_tool_call(), + output="User SSN is 123-45-6789", + ) + result = await check_for_ssn.run(data) + assert result.tripwire_triggered is True + assert result.model_message == "Output blocked: contains SSN" + assert result.output_info["blocked_pattern"] == "SSN" + + # Test with safe output - should pass + data = ToolOutputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + tool_call=get_mock_tool_call(), + output="User name is John Doe", + ) + result = await check_for_ssn.run(data) + assert result.tripwire_triggered is False + assert result.output_info == "safe_output" + + +def test_tool_input_guardrail_exception(): + """Test the tool input guardrail tripwire exception.""" + + @tool_input_guardrail + def test_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput( + output_info="test", + tripwire_triggered=True, + model_message="blocked", + ) + + output = ToolGuardrailFunctionOutput( + output_info="test", + tripwire_triggered=True, + model_message="blocked", + ) + + exception = ToolInputGuardrailTripwireTriggered( + guardrail=test_guardrail, + output=output, + ) + + assert exception.guardrail == test_guardrail + assert exception.output == output + assert "ToolInputGuardrail" in str(exception) + + +def test_tool_output_guardrail_exception(): + """Test the tool output guardrail tripwire exception.""" + + @tool_output_guardrail + def test_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput( + output_info="test", + tripwire_triggered=True, + model_message="blocked", + ) + + output = ToolGuardrailFunctionOutput( + output_info="test", + tripwire_triggered=True, + model_message="blocked", + ) + + exception = ToolOutputGuardrailTripwireTriggered( + guardrail=test_guardrail, + output=output, + ) + + assert exception.guardrail == test_guardrail + assert exception.output == output + assert "ToolOutputGuardrail" in str(exception) + + +if __name__ == "__main__": + # Run a simple test to verify functionality + async def main(): + print("Testing tool guardrails...") + + @tool_input_guardrail + def test_guard(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput( + output_info="test_passed", + tripwire_triggered=False, + ) + + print(f"✅ Created guardrail: {test_guard.get_name()}") + print("✅ All basic tests passed!") + + asyncio.run(main()) From 55c949a778ac397aaf0f92eeb8f11255329ded5e Mon Sep 17 00:00:00 2001 From: Steven C Date: Thu, 25 Sep 2025 10:21:58 -0400 Subject: [PATCH 2/4] removed docs. Added non-blocking behavior --- docs/guardrails.md | 82 +-------- docs/ref/tool_guardrails.md | 3 - examples/basic/tool_guardrails.py | 115 +++++++----- examples/basic/tools.py | 1 - mkdocs.yml | 1 - src/agents/__init__.py | 4 + src/agents/_run_impl.py | 77 +++++++- src/agents/result.py | 7 + src/agents/run.py | 9 + src/agents/tool_guardrails.py | 106 +++++++++-- tests/test_tool_guardrails.py | 287 ++++++++++++++++++------------ 11 files changed, 419 insertions(+), 273 deletions(-) delete mode 100644 docs/ref/tool_guardrails.md diff --git a/docs/guardrails.md b/docs/guardrails.md index f73f9d4fa..3fee17df0 100644 --- a/docs/guardrails.md +++ b/docs/guardrails.md @@ -151,84 +151,4 @@ async def main(): 1. This is the actual agent's output type. 2. This is the guardrail's output type. 3. This is the guardrail function that receives the agent's output, and returns the result. -4. This is the actual agent that defines the workflow. - -## Tool guardrails - -Tool guardrails provide fine-grained control over individual tool calls, allowing you to validate inputs and outputs at the tool level. This is particularly useful for: - -- Blocking sensitive data in tool arguments -- Preventing unauthorized access to certain tools -- Sanitizing tool outputs before they're returned -- Implementing custom validation logic for specific tools - -There are two types of tool guardrails: - -1. **Tool input guardrails** run before a tool is executed, validating the tool call arguments -2. **Tool output guardrails** run after a tool is executed, validating the tool's output - -### Tool input guardrails - -Tool input guardrails run in 3 steps: - -1. First, the guardrail receives the tool call data including arguments, context, and agent information -2. Next, the guardrail function runs to produce a [`ToolGuardrailFunctionOutput`][agents.tool_guardrails.ToolGuardrailFunctionOutput] -3. Finally, we check if [`.tripwire_triggered`][agents.tool_guardrails.ToolGuardrailFunctionOutput.tripwire_triggered] is true. If true, a [`ToolInputGuardrailTripwireTriggered`][agents.exceptions.ToolInputGuardrailTripwireTriggered] exception is raised - -### Tool output guardrails - -Tool output guardrails run in 3 steps: - -1. First, the guardrail receives the tool call data plus the tool's output -2. Next, the guardrail function runs to produce a [`ToolGuardrailFunctionOutput`][agents.tool_guardrails.ToolGuardrailFunctionOutput] -3. Finally, we check if [`.tripwire_triggered`][agents.tool_guardrails.ToolGuardrailFunctionOutput.tripwire_triggered] is true. If true, a [`ToolOutputGuardrailTripwireTriggered`][agents.exceptions.ToolOutputGuardrailTripwireTriggered] exception is raised - -### Implementing tool guardrails - -You can create tool guardrails using the `@tool_input_guardrail` and `@tool_output_guardrail` decorators: - -```python -from agents import ( - ToolGuardrailFunctionOutput, - ToolInputGuardrailData, - ToolOutputGuardrailData, - tool_input_guardrail, - tool_output_guardrail, -) - -@tool_input_guardrail -def block_sensitive_words(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: - """Block tool calls that contain sensitive words in arguments.""" - # Check arguments for sensitive content - if "password" in data.tool_call.arguments.lower(): - return ToolGuardrailFunctionOutput( - tripwire_triggered=True, - model_message="🚨 Tool call blocked: contains sensitive word", - output_info={"blocked_word": "password"}, - ) - return ToolGuardrailFunctionOutput(tripwire_triggered=False, output_info="Input validated") - -@tool_output_guardrail -def block_sensitive_output(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: - """Block tool outputs that contain sensitive data.""" - if "ssn" in str(data.output).lower(): - return ToolGuardrailFunctionOutput( - tripwire_triggered=True, - model_message="🚨 Tool output blocked: contains sensitive data", - output_info={"blocked_pattern": "SSN"}, - ) - return ToolGuardrailFunctionOutput(tripwire_triggered=False, output_info="Output validated") - -# Apply guardrails to tools -my_tool.tool_input_guardrails = [block_sensitive_words] -my_tool.tool_output_guardrails = [block_sensitive_output] -``` - -For a complete working example, see [tool_guardrails.py](https://github.com/openai/openai-agents-python/blob/main/examples/basic/tool_guardrails.py). - -### Key differences from agent guardrails - -- **Scope**: Tool guardrails operate on individual tool calls, while agent guardrails operate on the entire agent input/output -- **Timing**: Tool guardrails run immediately before/after tool execution, while agent guardrails run at the beginning/end of agent execution -- **Data access**: Tool guardrails have access to the specific tool call arguments and outputs, plus the tool context -- **Application**: Tool guardrails are applied directly to function tools via the `tool_input_guardrails` and `tool_output_guardrails` attributes +4. This is the actual agent that defines the workflow. \ No newline at end of file diff --git a/docs/ref/tool_guardrails.md b/docs/ref/tool_guardrails.md deleted file mode 100644 index bc3639304..000000000 --- a/docs/ref/tool_guardrails.md +++ /dev/null @@ -1,3 +0,0 @@ -# `Tool Guardrails` - -::: agents.tool_guardrails diff --git a/examples/basic/tool_guardrails.py b/examples/basic/tool_guardrails.py index c85d1c6eb..0e7af6928 100644 --- a/examples/basic/tool_guardrails.py +++ b/examples/basic/tool_guardrails.py @@ -6,7 +6,6 @@ Runner, ToolGuardrailFunctionOutput, ToolInputGuardrailData, - ToolInputGuardrailTripwireTriggered, ToolOutputGuardrailData, ToolOutputGuardrailTripwireTriggered, function_tool, @@ -33,16 +32,24 @@ def get_user_data(user_id: str) -> dict[str, str]: "phone": "555-1234", } +@function_tool +def get_contact_info(user_id: str) -> dict[str, str]: + """Get contact info by ID.""" + return { + "user_id": user_id, + "name": "Jane Smith", + "email": "jane@example.com", + "phone": "555-1234", + } + @tool_input_guardrail -def block_sensitive_words(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: - """Block tool calls that contain sensitive words in arguments.""" +def reject_sensitive_words(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + """Reject tool calls that contain sensitive words in arguments.""" try: - args = json.loads(data.tool_call.arguments) + args = json.loads(data.context.tool_arguments) if data.context.tool_arguments else {} except json.JSONDecodeError: - return ToolGuardrailFunctionOutput( - tripwire_triggered=False, output_info="Invalid JSON arguments" - ) + return ToolGuardrailFunctionOutput(output_info="Invalid JSON arguments") # Check for suspicious content sensitive_words = [ @@ -50,19 +57,19 @@ def block_sensitive_words(data: ToolInputGuardrailData) -> ToolGuardrailFunction "hack", "exploit", "malware", - "orange", - ] # to mock sensitive words + "ACME", + ] for key, value in args.items(): value_str = str(value).lower() for word in sensitive_words: - if word in value_str: - return ToolGuardrailFunctionOutput( - tripwire_triggered=True, - model_message=f"🚨 Tool call blocked: contains '{word}'", + if word.lower() in value_str: + # Reject tool call and inform the model the function was not called + return ToolGuardrailFunctionOutput.reject_content( + message=f"🚨 Tool call blocked: contains '{word}'", output_info={"blocked_word": word, "argument": key}, ) - return ToolGuardrailFunctionOutput(tripwire_triggered=False, output_info="Input validated") + return ToolGuardrailFunctionOutput(output_info="Input validated") @tool_output_guardrail @@ -72,57 +79,71 @@ def block_sensitive_output(data: ToolOutputGuardrailData) -> ToolGuardrailFuncti # Check for sensitive data patterns if "ssn" in output_str or "123-45-6789" in output_str: - return ToolGuardrailFunctionOutput( - tripwire_triggered=True, - model_message="🚨 Tool output blocked: contains sensitive data", - output_info={"blocked_pattern": "SSN", "tool": data.tool_call.name}, + # Use raise_exception to halt execution completely for sensitive data + return ToolGuardrailFunctionOutput.raise_exception( + output_info={"blocked_pattern": "SSN", "tool": data.context.tool_name}, ) - return ToolGuardrailFunctionOutput(tripwire_triggered=False, output_info="Output validated") + return ToolGuardrailFunctionOutput(output_info="Output validated") + + +@tool_output_guardrail +def reject_phone_numbers(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + """Reject function output containing phone numbers.""" + output_str = str(data.output) + if "555-1234" in output_str: + return ToolGuardrailFunctionOutput.reject_content( + message="User data not retrieved as it contains a phone number which is restricted.", + output_info={"redacted": "phone_number"}, + ) + return ToolGuardrailFunctionOutput() # Apply guardrails to tools -send_email.tool_input_guardrails = [block_sensitive_words] +send_email.tool_input_guardrails = [reject_sensitive_words] get_user_data.tool_output_guardrails = [block_sensitive_output] +get_contact_info.tool_output_guardrails = [reject_phone_numbers] agent = Agent( name="Secure Assistant", instructions="You are a helpful assistant with access to email and user data tools.", - tools=[send_email, get_user_data], + tools=[send_email, get_user_data, get_contact_info], ) async def main(): print("=== Tool Guardrails Example ===\n") - # Example 1: Normal operation - should work fine - print("1. Normal email sending:") try: + # Example 1: Normal operation - should work fine + print("1. Normal email sending:") result = await Runner.run(agent, "Send a welcome email to john@example.com") - print(f"✅ Success: {result.final_output}\n") - except Exception as e: - print(f"❌ Error: {e}\n") + print(f"✅ Successful tool execution: {result.final_output}\n") - # Example 2: Input guardrail triggers - should block suspicious content - print("2. Attempting to send email with suspicious content:") - try: - result = await Runner.run( - agent, "Send an email to john@example.com with the subject: orange" - ) - print(f"✅ Success: {result.final_output}\n") - except ToolInputGuardrailTripwireTriggered as e: - print(f"🚨 Input guardrail triggered: {e.output.model_message}") - print(f" Details: {e.output.output_info}\n") + # Example 2: Input guardrail triggers - function tool call is rejected but execution continues + print("2. Attempting to send email with suspicious content:") + result = await Runner.run(agent, "Send an email to john@example.com introducing the company ACME corp.") + print(f"❌ Guardrail rejected function tool call: {result.final_output}\n") + except Exception as e: + print(f"Error: {e}\n") - # Example 3: Output guardrail triggers - should block sensitive data - print("3. Attempting to get user data (contains SSN):") try: + # Example 3: Output guardrail triggers - should raise exception for sensitive data + print("3. Attempting to get user data (contains SSN). Execution blocked:") result = await Runner.run(agent, "Get the data for user ID user123") - print(f"✅ Success: {result.final_output}\n") + print(f"✅ Successful tool execution: {result.final_output}\n") except ToolOutputGuardrailTripwireTriggered as e: - print(f"🚨 Output guardrail triggered: {e.output.model_message}") - print(f" Details: {e.output.output_info}\n") + print("🚨 Output guardrail triggered: Execution halted for sensitive data") + print(f"Details: {e.output.output_info}\n") + + try: + # Example 4: Output guardrail triggers - reject returning function tool output but continue execution + print("4. Rejecting function tool output containing phone numbers:") + result = await Runner.run(agent, "Get contact info for user456") + print(f"❌ Guardrail rejected function tool output: {result.final_output}\n") + except Exception as e: + print(f"Error: {e}\n") if __name__ == "__main__": asyncio.run(main()) @@ -133,13 +154,15 @@ async def main(): === Tool Guardrails Example === 1. Normal email sending: -✅ Success: I've sent a welcome email to john@example.com with an appropriate subject and greeting message. +✅ Successful tool execution: I've sent a welcome email to john@example.com with an appropriate subject and greeting message. 2. Attempting to send email with suspicious content: -🚨 Input guardrail triggered: 🚨 Tool call blocked: contains 'orange' - Details: {'blocked_word': 'orange', 'argument': 'subject'} +❌ Guardrail rejected function tool call: I'm unable to send the email mentioning ACME Corp as it was blocked by security guardrails. -3. Attempting to get user data (contains SSN): -🚨 Output guardrail triggered: 🚨 Tool output blocked: contains sensitive data +3. Attempting to get user data (contains SSN). Execution blocked: +🚨 Output guardrail triggered: Execution halted for sensitive data Details: {'blocked_pattern': 'SSN', 'tool': 'get_user_data'} + +4. Rejecting function tool output containing sensitive data: +✅ Successful tool execution: User data retrieved (phone number redacted for privacy) """ diff --git a/examples/basic/tools.py b/examples/basic/tools.py index 2052d9427..1c4496603 100644 --- a/examples/basic/tools.py +++ b/examples/basic/tools.py @@ -18,7 +18,6 @@ def get_weather(city: Annotated[str, "The city to get the weather for"]) -> Weat print("[debug] get_weather called") return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.") - agent = Agent( name="Hello world", instructions="You are a helpful agent.", diff --git a/mkdocs.yml b/mkdocs.yml index c4096d6ab..bea747bed 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -101,7 +101,6 @@ plugins: - ref/usage.md - ref/exceptions.md - ref/guardrail.md - - ref/tool_guardrails.md - ref/model_settings.md - ref/agent_output.md - ref/function_schema.md diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 1c85950db..d6ff56d40 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -89,8 +89,10 @@ ToolGuardrailFunctionOutput, ToolInputGuardrail, ToolInputGuardrailData, + ToolInputGuardrailResult, ToolOutputGuardrail, ToolOutputGuardrailData, + ToolOutputGuardrailResult, tool_input_guardrail, tool_output_guardrail, ) @@ -221,7 +223,9 @@ def enable_verbose_stdout_logging(): "ToolOutputGuardrail", "ToolGuardrailFunctionOutput", "ToolInputGuardrailData", + "ToolInputGuardrailResult", "ToolOutputGuardrailData", + "ToolOutputGuardrailResult", "tool_input_guardrail", "tool_output_guardrail", "handoff", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 342cd569b..75c770dd2 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -88,7 +88,9 @@ from .tool_context import ToolContext from .tool_guardrails import ( ToolInputGuardrailData, + ToolInputGuardrailResult, ToolOutputGuardrailData, + ToolOutputGuardrailResult, ) from .tracing import ( SpanError, @@ -218,6 +220,12 @@ class SingleStepResult: next_step: NextStepHandoff | NextStepFinalOutput | NextStepRunAgain """The next step to take.""" + tool_input_guardrail_results: list[ToolInputGuardrailResult] + """Tool input guardrail results from this step.""" + + tool_output_guardrail_results: list[ToolOutputGuardrailResult] + """Tool output guardrail results from this step.""" + @property def generated_items(self) -> list[RunItem]: """Items generated during the agent run (i.e. everything generated after @@ -260,7 +268,10 @@ async def execute_tools_and_side_effects( new_step_items.extend(processed_response.new_items) # First, lets run the tool calls - function tools and computer actions - function_results, computer_results = await asyncio.gather( + ( + (function_results, tool_input_guardrail_results, tool_output_guardrail_results), + computer_results, + ) = await asyncio.gather( cls.execute_function_tool_calls( agent=agent, tool_runs=processed_response.functions, @@ -353,6 +364,8 @@ async def execute_tools_and_side_effects( final_output=final_output, hooks=hooks, context_wrapper=context_wrapper, + tool_input_guardrail_results=tool_input_guardrail_results, + tool_output_guardrail_results=tool_output_guardrail_results, ) elif not output_schema or output_schema.is_plain_text(): return await cls.execute_final_output( @@ -364,6 +377,8 @@ async def execute_tools_and_side_effects( final_output=potential_final_output_text or "", hooks=hooks, context_wrapper=context_wrapper, + tool_input_guardrail_results=tool_input_guardrail_results, + tool_output_guardrail_results=tool_output_guardrail_results, ) # If there's no final output, we can just run again @@ -373,6 +388,8 @@ async def execute_tools_and_side_effects( pre_step_items=pre_step_items, new_step_items=new_step_items, next_step=NextStepRunAgain(), + tool_input_guardrail_results=tool_input_guardrail_results, + tool_output_guardrail_results=tool_output_guardrail_results, ) @classmethod @@ -566,7 +583,13 @@ async def execute_function_tool_calls( hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], config: RunConfig, - ) -> list[FunctionToolResult]: + ) -> tuple[ + list[FunctionToolResult], list[ToolInputGuardrailResult], list[ToolOutputGuardrailResult] + ]: + # Collect guardrail results + tool_input_guardrail_results: list[ToolInputGuardrailResult] = [] + tool_output_guardrail_results: list[ToolOutputGuardrailResult] = [] + async def run_single_tool( func_tool: FunctionTool, tool_call: ResponseFunctionToolCall ) -> Any: @@ -587,14 +610,29 @@ async def run_single_tool( ToolInputGuardrailData( context=tool_context, agent=agent, - tool_call=tool_call, ) ) - if gr_out.tripwire_triggered: - # Raise tripwire exception instead of just setting result + + # Store the guardrail result + tool_input_guardrail_results.append( + ToolInputGuardrailResult( + guardrail=guardrail, + output=gr_out, + ) + ) + + # Handle different behavior types + if gr_out.behavior["type"] == "raise_exception": raise ToolInputGuardrailTripwireTriggered( guardrail=guardrail, output=gr_out ) + elif gr_out.behavior["type"] == "reject_content": + # Set final_result to the message and skip tool execution + final_result = gr_out.behavior["message"] + break + elif gr_out.behavior["type"] == "allow": + # Continue to next guardrail or tool execution + continue if final_result is None: # 2) Actually run the tool @@ -618,15 +656,30 @@ async def run_single_tool( ToolOutputGuardrailData( context=tool_context, agent=agent, - tool_call=tool_call, output=real_result, ) ) - if gr_out.tripwire_triggered: - # Raise tripwire exception instead of just setting result + + # Store the guardrail result + tool_output_guardrail_results.append( + ToolOutputGuardrailResult( + guardrail=output_guardrail, + output=gr_out, + ) + ) + + # Handle different behavior types + if gr_out.behavior["type"] == "raise_exception": raise ToolOutputGuardrailTripwireTriggered( guardrail=output_guardrail, output=gr_out ) + elif gr_out.behavior["type"] == "reject_content": + # Override the result with the guardrail message + final_result = gr_out.behavior["message"] + break + elif gr_out.behavior["type"] == "allow": + # Continue to next guardrail + continue # 4) Tool end hooks (with final result, which may have been overridden) await asyncio.gather( hooks.on_tool_end(tool_context, agent, func_tool, final_result), @@ -661,7 +714,7 @@ async def run_single_tool( results = await asyncio.gather(*tasks) - return [ + function_tool_results = [ FunctionToolResult( tool=tool_run.function_tool, output=result, @@ -674,6 +727,8 @@ async def run_single_tool( for tool_run, result in zip(tool_runs, results) ] + return function_tool_results, tool_input_guardrail_results, tool_output_guardrail_results + @classmethod async def execute_local_shell_calls( cls, @@ -925,6 +980,8 @@ async def execute_final_output( final_output: Any, hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], + tool_input_guardrail_results: list[ToolInputGuardrailResult], + tool_output_guardrail_results: list[ToolOutputGuardrailResult], ) -> SingleStepResult: # Run the on_end hooks await cls.run_final_output_hooks(agent, hooks, context_wrapper, final_output) @@ -935,6 +992,8 @@ async def execute_final_output( pre_step_items=pre_step_items, new_step_items=new_step_items, next_step=NextStepFinalOutput(final_output), + tool_input_guardrail_results=tool_input_guardrail_results, + tool_output_guardrail_results=tool_output_guardrail_results, ) @classmethod diff --git a/src/agents/result.py b/src/agents/result.py index 26609da28..a76a98b29 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -31,6 +31,7 @@ if TYPE_CHECKING: from ._run_impl import QueueCompleteSentinel from .agent import Agent + from .tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult T = TypeVar("T") @@ -59,6 +60,12 @@ class RunResultBase(abc.ABC): output_guardrail_results: list[OutputGuardrailResult] """Guardrail results for the final output of the agent.""" + tool_input_guardrail_results: list[ToolInputGuardrailResult] + """Tool input guardrail results from all tools executed during the run.""" + + tool_output_guardrail_results: list[ToolOutputGuardrailResult] + """Tool output guardrail results from all tools executed during the run.""" + context_wrapper: RunContextWrapper[Any] """The context wrapper for the agent run.""" diff --git a/src/agents/run.py b/src/agents/run.py index 722aa1558..0b4df3174 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -68,6 +68,7 @@ StreamEvent, ) from .tool import Tool +from .tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult from .tracing import Span, SpanError, agent_span, get_current_trace, trace from .tracing.span_data import AgentSpanData from .usage import Usage @@ -494,6 +495,8 @@ async def run( ) input_guardrail_results: list[InputGuardrailResult] = [] + tool_input_guardrail_results: list[ToolInputGuardrailResult] = [] + tool_output_guardrail_results: list[ToolOutputGuardrailResult] = [] current_span: Span[AgentSpanData] | None = None current_agent = starting_agent @@ -584,6 +587,10 @@ async def run( original_input = turn_result.original_input generated_items = turn_result.generated_items + # Collect tool guardrail results from this turn + tool_input_guardrail_results.extend(turn_result.tool_input_guardrail_results) + tool_output_guardrail_results.extend(turn_result.tool_output_guardrail_results) + if isinstance(turn_result.next_step, NextStepFinalOutput): output_guardrail_results = await self._run_output_guardrails( current_agent.output_guardrails + (run_config.output_guardrails or []), @@ -599,6 +606,8 @@ async def run( _last_agent=current_agent, input_guardrail_results=input_guardrail_results, output_guardrail_results=output_guardrail_results, + tool_input_guardrail_results=tool_input_guardrail_results, + tool_output_guardrail_results=tool_output_guardrail_results, context_wrapper=context_wrapper, ) await self._save_result_to_session(session, [], turn_result.new_step_items) diff --git a/src/agents/tool_guardrails.py b/src/agents/tool_guardrails.py index 93386c04e..d02bd0dc5 100644 --- a/src/agents/tool_guardrails.py +++ b/src/agents/tool_guardrails.py @@ -3,10 +3,9 @@ import inspect from collections.abc import Awaitable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Generic, overload +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, overload -from openai.types.responses import ResponseFunctionToolCall -from typing_extensions import TypeVar +from typing_extensions import TypedDict, TypeVar from .exceptions import UserError from .tool_context import ToolContext @@ -16,6 +15,47 @@ from .agent import Agent +@dataclass +class ToolInputGuardrailResult: + """The result of a tool input guardrail run.""" + + guardrail: ToolInputGuardrail[Any] + """The guardrail that was run.""" + + output: ToolGuardrailFunctionOutput + """The output of the guardrail function.""" + + +@dataclass +class ToolOutputGuardrailResult: + """The result of a tool output guardrail run.""" + + guardrail: ToolOutputGuardrail[Any] + """The guardrail that was run.""" + + output: ToolGuardrailFunctionOutput + """The output of the guardrail function.""" + + +class RejectContentBehavior(TypedDict): + """Rejects the tool call/output but continues execution with a message to the model.""" + + type: Literal["reject_content"] + message: str + + +class RaiseExceptionBehavior(TypedDict): + """Raises an exception to halt execution.""" + + type: Literal["raise_exception"] + + +class AllowBehavior(TypedDict): + """Allows normal tool execution to continue.""" + + type: Literal["allow"] + + @dataclass class ToolGuardrailFunctionOutput: """The output of a tool guardrail function.""" @@ -26,15 +66,58 @@ class ToolGuardrailFunctionOutput: information about the checks it performed and granular results. """ - tripwire_triggered: bool + behavior: RejectContentBehavior | RaiseExceptionBehavior | AllowBehavior | None = None """ - Whether the tripwire was triggered. If triggered, the tool execution will be halted. + Defines how the system should respond when this guardrail result is processed. + - None/allow: Allow normal tool execution to continue without interference (default) + - reject_content: Reject the tool call/output but continue execution with a message to the model + - raise_exception: Halt execution by raising a ToolGuardrailTripwireTriggered exception """ - model_message: str | None = None - """ - Message to send back to the model as the tool output if tripped. - """ + def __post_init__(self) -> None: + """Set default behavior if none specified.""" + if self.behavior is None: + self.behavior = AllowBehavior(type="allow") + + @classmethod + def allow(cls, output_info: Any = None) -> ToolGuardrailFunctionOutput: + """Create a guardrail output that allows the tool execution to continue normally. + + Args: + output_info: Optional data about checks performed. + + Returns: + ToolGuardrailFunctionOutput configured to allow normal execution. + """ + return cls(output_info=output_info, behavior=AllowBehavior(type="allow")) + + @classmethod + def reject_content(cls, message: str, output_info: Any = None) -> ToolGuardrailFunctionOutput: + """Create a guardrail output that rejects the tool call/output but continues execution. + + Args: + message: Message to send to the model instead of the tool result. + output_info: Optional data about checks performed. + + Returns: + ToolGuardrailFunctionOutput configured to reject the content. + """ + return cls( + output_info=output_info, + behavior=RejectContentBehavior(type="reject_content", message=message), + ) + + @classmethod + def raise_exception(cls, output_info: Any = None) -> ToolGuardrailFunctionOutput: + """Create a guardrail output that raises an exception to halt execution. + + Args: + output_info: Optional data about checks performed. + + Returns: + ToolGuardrailFunctionOutput configured to raise an exception. + """ + return cls(output_info=output_info, behavior=RaiseExceptionBehavior(type="raise_exception")) @dataclass @@ -51,11 +134,6 @@ class ToolInputGuardrailData: The agent that is executing the tool. """ - tool_call: ResponseFunctionToolCall - """ - The tool call data including the function name and arguments. - """ - @dataclass class ToolOutputGuardrailData(ToolInputGuardrailData): diff --git a/tests/test_tool_guardrails.py b/tests/test_tool_guardrails.py index 0b36b8abc..8ccaec0ad 100644 --- a/tests/test_tool_guardrails.py +++ b/tests/test_tool_guardrails.py @@ -4,7 +4,6 @@ from typing import Any import pytest -from openai.types.responses import ResponseFunctionToolCall from agents import ( Agent, @@ -21,26 +20,24 @@ from agents.tool_guardrails import tool_input_guardrail, tool_output_guardrail -def get_mock_tool_call(arguments: str = "{}") -> ResponseFunctionToolCall: - """Helper to create a mock tool call for testing.""" - return ResponseFunctionToolCall( - call_id="call_123", type="function_call", name="test_tool", arguments=arguments - ) - - -def get_mock_tool_context() -> ToolContext: +def get_mock_tool_context(tool_arguments: str = '{"param": "value"}') -> ToolContext: """Helper to create a mock tool context for testing.""" - return ToolContext(context=None, tool_name="test_tool", tool_call_id="call_123") + return ToolContext( + context=None, + tool_name="test_tool", + tool_call_id="call_123", + tool_arguments=tool_arguments, + ) def get_sync_input_guardrail(triggers: bool, output_info: Any | None = None): """Helper to create a sync input guardrail function.""" def sync_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: - return ToolGuardrailFunctionOutput( - output_info=output_info, - tripwire_triggered=triggers, - ) + if triggers: + return ToolGuardrailFunctionOutput.raise_exception(output_info=output_info) + else: + return ToolGuardrailFunctionOutput.allow(output_info=output_info) return sync_guardrail @@ -49,10 +46,10 @@ def get_async_input_guardrail(triggers: bool, output_info: Any | None = None): """Helper to create an async input guardrail function.""" async def async_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: - return ToolGuardrailFunctionOutput( - output_info=output_info, - tripwire_triggered=triggers, - ) + if triggers: + return ToolGuardrailFunctionOutput.raise_exception(output_info=output_info) + else: + return ToolGuardrailFunctionOutput.allow(output_info=output_info) return async_guardrail @@ -61,10 +58,10 @@ def get_sync_output_guardrail(triggers: bool, output_info: Any | None = None): """Helper to create a sync output guardrail function.""" def sync_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: - return ToolGuardrailFunctionOutput( - output_info=output_info, - tripwire_triggered=triggers, - ) + if triggers: + return ToolGuardrailFunctionOutput.raise_exception(output_info=output_info) + else: + return ToolGuardrailFunctionOutput.allow(output_info=output_info) return sync_guardrail @@ -73,10 +70,10 @@ def get_async_output_guardrail(triggers: bool, output_info: Any | None = None): """Helper to create an async output guardrail function.""" async def async_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: - return ToolGuardrailFunctionOutput( - output_info=output_info, - tripwire_triggered=triggers, - ) + if triggers: + return ToolGuardrailFunctionOutput.raise_exception(output_info=output_info) + else: + return ToolGuardrailFunctionOutput.allow(output_info=output_info) return async_guardrail @@ -91,10 +88,9 @@ async def test_sync_tool_input_guardrail(): data = ToolInputGuardrailData( context=get_mock_tool_context(), agent=Agent(name="test"), - tool_call=get_mock_tool_call(), ) result = await guardrail.run(data) - assert not result.tripwire_triggered + assert result.behavior["type"] == "allow" assert result.output_info is None # Test triggering guardrail @@ -102,7 +98,7 @@ async def test_sync_tool_input_guardrail(): guardrail_function=get_sync_input_guardrail(triggers=True) ) result = await guardrail_2.run(data) - assert result.tripwire_triggered + assert result.behavior["type"] == "raise_exception" assert result.output_info is None # Test triggering guardrail with output info @@ -110,7 +106,7 @@ async def test_sync_tool_input_guardrail(): guardrail_function=get_sync_input_guardrail(triggers=True, output_info="test_info") ) result = await guardrail_3.run(data) - assert result.tripwire_triggered + assert result.behavior["type"] == "raise_exception" assert result.output_info == "test_info" @@ -124,10 +120,9 @@ async def test_async_tool_input_guardrail(): data = ToolInputGuardrailData( context=get_mock_tool_context(), agent=Agent(name="test"), - tool_call=get_mock_tool_call(), ) result = await guardrail.run(data) - assert not result.tripwire_triggered + assert result.behavior["type"] == "allow" assert result.output_info is None # Test triggering guardrail @@ -135,7 +130,7 @@ async def test_async_tool_input_guardrail(): guardrail_function=get_async_input_guardrail(triggers=True) ) result = await guardrail_2.run(data) - assert result.tripwire_triggered + assert result.behavior["type"] == "raise_exception" assert result.output_info is None # Test triggering guardrail with output info @@ -143,7 +138,7 @@ async def test_async_tool_input_guardrail(): guardrail_function=get_async_input_guardrail(triggers=True, output_info="test_info") ) result = await guardrail_3.run(data) - assert result.tripwire_triggered + assert result.behavior["type"] == "raise_exception" assert result.output_info == "test_info" @@ -157,11 +152,10 @@ async def test_sync_tool_output_guardrail(): data = ToolOutputGuardrailData( context=get_mock_tool_context(), agent=Agent(name="test"), - tool_call=get_mock_tool_call(), output="test output", ) result = await guardrail.run(data) - assert not result.tripwire_triggered + assert result.behavior["type"] == "allow" assert result.output_info is None # Test triggering guardrail @@ -169,7 +163,7 @@ async def test_sync_tool_output_guardrail(): guardrail_function=get_sync_output_guardrail(triggers=True) ) result = await guardrail_2.run(data) - assert result.tripwire_triggered + assert result.behavior["type"] == "raise_exception" assert result.output_info is None # Test triggering guardrail with output info @@ -177,7 +171,7 @@ async def test_sync_tool_output_guardrail(): guardrail_function=get_sync_output_guardrail(triggers=True, output_info="test_info") ) result = await guardrail_3.run(data) - assert result.tripwire_triggered + assert result.behavior["type"] == "raise_exception" assert result.output_info == "test_info" @@ -191,11 +185,10 @@ async def test_async_tool_output_guardrail(): data = ToolOutputGuardrailData( context=get_mock_tool_context(), agent=Agent(name="test"), - tool_call=get_mock_tool_call(), output="test output", ) result = await guardrail.run(data) - assert not result.tripwire_triggered + assert result.behavior["type"] == "allow" assert result.output_info is None # Test triggering guardrail @@ -203,7 +196,7 @@ async def test_async_tool_output_guardrail(): guardrail_function=get_async_output_guardrail(triggers=True) ) result = await guardrail_2.run(data) - assert result.tripwire_triggered + assert result.behavior["type"] == "raise_exception" assert result.output_info is None # Test triggering guardrail with output info @@ -211,7 +204,7 @@ async def test_async_tool_output_guardrail(): guardrail_function=get_async_output_guardrail(triggers=True, output_info="test_info") ) result = await guardrail_3.run(data) - assert result.tripwire_triggered + assert result.behavior["type"] == "raise_exception" assert result.output_info == "test_info" @@ -224,7 +217,6 @@ async def test_invalid_tool_input_guardrail_raises_user_error(): data = ToolInputGuardrailData( context=get_mock_tool_context(), agent=Agent(name="test"), - tool_call=get_mock_tool_call(), ) await guardrail.run(data) @@ -238,7 +230,6 @@ async def test_invalid_tool_output_guardrail_raises_user_error(): data = ToolOutputGuardrailData( context=get_mock_tool_context(), agent=Agent(name="test"), - tool_call=get_mock_tool_call(), output="test output", ) await guardrail.run(data) @@ -249,18 +240,12 @@ async def test_invalid_tool_output_guardrail_raises_user_error(): @tool_input_guardrail def decorated_input_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: - return ToolGuardrailFunctionOutput( - output_info="test_1", - tripwire_triggered=False, - ) + return ToolGuardrailFunctionOutput.allow(output_info="test_1") @tool_input_guardrail(name="Custom input name") def decorated_named_input_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: - return ToolGuardrailFunctionOutput( - output_info="test_2", - tripwire_triggered=False, - ) + return ToolGuardrailFunctionOutput.allow(output_info="test_2") @pytest.mark.asyncio @@ -269,37 +254,30 @@ async def test_tool_input_guardrail_decorators(): data = ToolInputGuardrailData( context=get_mock_tool_context(), agent=Agent(name="test"), - tool_call=get_mock_tool_call(), ) # Test basic decorator guardrail = decorated_input_guardrail result = await guardrail.run(data) - assert not result.tripwire_triggered + assert result.behavior["type"] == "allow" assert result.output_info == "test_1" # Test named decorator guardrail = decorated_named_input_guardrail result = await guardrail.run(data) - assert not result.tripwire_triggered + assert result.behavior["type"] == "allow" assert result.output_info == "test_2" assert guardrail.get_name() == "Custom input name" @tool_output_guardrail def decorated_output_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: - return ToolGuardrailFunctionOutput( - output_info="test_3", - tripwire_triggered=False, - ) + return ToolGuardrailFunctionOutput.allow(output_info="test_3") @tool_output_guardrail(name="Custom output name") def decorated_named_output_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: - return ToolGuardrailFunctionOutput( - output_info="test_4", - tripwire_triggered=False, - ) + return ToolGuardrailFunctionOutput.allow(output_info="test_4") @pytest.mark.asyncio @@ -308,20 +286,19 @@ async def test_tool_output_guardrail_decorators(): data = ToolOutputGuardrailData( context=get_mock_tool_context(), agent=Agent(name="test"), - tool_call=get_mock_tool_call(), output="test output", ) # Test basic decorator guardrail = decorated_output_guardrail result = await guardrail.run(data) - assert not result.tripwire_triggered + assert result.behavior["type"] == "allow" assert result.output_info == "test_3" # Test named decorator guardrail = decorated_named_output_guardrail result = await guardrail.run(data) - assert not result.tripwire_triggered + assert result.behavior["type"] == "allow" assert result.output_info == "test_4" assert guardrail.get_name() == "Custom output name" @@ -335,36 +312,30 @@ async def test_password_blocking_input_guardrail(): @tool_input_guardrail def check_for_password(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: - if "password" in data.tool_call.arguments.lower(): - return ToolGuardrailFunctionOutput( + if "password" in data.context.tool_arguments.lower(): + return ToolGuardrailFunctionOutput.reject_content( + message="Tool call blocked: contains password", output_info={"blocked_word": "password"}, - tripwire_triggered=True, - model_message="Tool call blocked: contains password", ) - return ToolGuardrailFunctionOutput( - output_info="safe_input", - tripwire_triggered=False, - ) + return ToolGuardrailFunctionOutput(output_info="safe_input") # Test with password - should trigger data = ToolInputGuardrailData( - context=get_mock_tool_context(), + context=get_mock_tool_context('{"message": "Hello password world"}'), agent=Agent(name="test"), - tool_call=get_mock_tool_call('{"message": "Hello password world"}'), ) result = await check_for_password.run(data) - assert result.tripwire_triggered is True - assert result.model_message == "Tool call blocked: contains password" + assert result.behavior["type"] == "reject_content" + assert result.behavior["message"] == "Tool call blocked: contains password" assert result.output_info["blocked_word"] == "password" # Test without password - should pass data = ToolInputGuardrailData( - context=get_mock_tool_context(), + context=get_mock_tool_context('{"message": "Hello safe world"}'), agent=Agent(name="test"), - tool_call=get_mock_tool_call('{"message": "Hello safe world"}'), ) result = await check_for_password.run(data) - assert result.tripwire_triggered is False + assert result.behavior["type"] == "allow" assert result.output_info == "safe_input" @@ -376,37 +347,29 @@ async def test_ssn_blocking_output_guardrail(): def check_for_ssn(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: output_str = str(data.output).lower() if "ssn" in output_str or "123-45-6789" in output_str: - return ToolGuardrailFunctionOutput( - output_info={"blocked_pattern": "SSN"}, - tripwire_triggered=True, - model_message="Output blocked: contains SSN", + return ToolGuardrailFunctionOutput.raise_exception( + output_info={"blocked_pattern": "SSN"} ) - return ToolGuardrailFunctionOutput( - output_info="safe_output", - tripwire_triggered=False, - ) + return ToolGuardrailFunctionOutput(output_info="safe_output") # Test with SSN in output - should trigger data = ToolOutputGuardrailData( context=get_mock_tool_context(), agent=Agent(name="test"), - tool_call=get_mock_tool_call(), output="User SSN is 123-45-6789", ) result = await check_for_ssn.run(data) - assert result.tripwire_triggered is True - assert result.model_message == "Output blocked: contains SSN" + assert result.behavior["type"] == "raise_exception" assert result.output_info["blocked_pattern"] == "SSN" # Test with safe output - should pass data = ToolOutputGuardrailData( context=get_mock_tool_context(), agent=Agent(name="test"), - tool_call=get_mock_tool_call(), output="User name is John Doe", ) result = await check_for_ssn.run(data) - assert result.tripwire_triggered is False + assert result.behavior["type"] == "allow" assert result.output_info == "safe_output" @@ -415,17 +378,9 @@ def test_tool_input_guardrail_exception(): @tool_input_guardrail def test_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: - return ToolGuardrailFunctionOutput( - output_info="test", - tripwire_triggered=True, - model_message="blocked", - ) + return ToolGuardrailFunctionOutput.raise_exception(output_info="test") - output = ToolGuardrailFunctionOutput( - output_info="test", - tripwire_triggered=True, - model_message="blocked", - ) + output = ToolGuardrailFunctionOutput.raise_exception(output_info="test") exception = ToolInputGuardrailTripwireTriggered( guardrail=test_guardrail, @@ -442,17 +397,9 @@ def test_tool_output_guardrail_exception(): @tool_output_guardrail def test_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: - return ToolGuardrailFunctionOutput( - output_info="test", - tripwire_triggered=True, - model_message="blocked", - ) + return ToolGuardrailFunctionOutput.raise_exception(output_info="test") - output = ToolGuardrailFunctionOutput( - output_info="test", - tripwire_triggered=True, - model_message="blocked", - ) + output = ToolGuardrailFunctionOutput.raise_exception(output_info="test") exception = ToolOutputGuardrailTripwireTriggered( guardrail=test_guardrail, @@ -464,6 +411,113 @@ def test_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput assert "ToolOutputGuardrail" in str(exception) +# Test new behavior system + + +@pytest.mark.asyncio +async def test_allow_behavior(): + """Test the allow behavior type.""" + + @tool_input_guardrail + def allow_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput.allow(output_info="allowed") + + data = ToolInputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + ) + result = await allow_guardrail.run(data) + assert result.behavior["type"] == "allow" + assert result.output_info == "allowed" + + +@pytest.mark.asyncio +async def test_reject_content_behavior(): + """Test the reject_content behavior type.""" + + @tool_input_guardrail + def reject_content_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput.reject_content( + message="Tool blocked by guardrail", output_info="rejected" + ) + + data = ToolInputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + ) + result = await reject_content_guardrail.run(data) + assert result.behavior["type"] == "reject_content" + assert result.behavior["message"] == "Tool blocked by guardrail" + assert result.output_info == "rejected" + + +@pytest.mark.asyncio +async def test_raise_exception_behavior(): + """Test the raise_exception behavior type.""" + + @tool_input_guardrail + def raise_exception_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput.raise_exception(output_info="exception") + + data = ToolInputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + ) + result = await raise_exception_guardrail.run(data) + assert result.behavior["type"] == "raise_exception" + assert result.output_info == "exception" + + +@pytest.mark.asyncio +async def test_mixed_behavior_output_guardrail(): + """Test mixing different behavior types in output guardrails.""" + + @tool_output_guardrail + def mixed_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + output_str = str(data.output).lower() + if "dangerous" in output_str: + return ToolGuardrailFunctionOutput.raise_exception( + output_info={"reason": "dangerous_content"} + ) + elif "sensitive" in output_str: + return ToolGuardrailFunctionOutput.reject_content( + message="Content was filtered", output_info={"reason": "sensitive_content"} + ) + else: + return ToolGuardrailFunctionOutput(output_info={"status": "clean"}) + + # Test dangerous content (should raise exception) + data_dangerous = ToolOutputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + output="This is dangerous content", + ) + result = await mixed_guardrail.run(data_dangerous) + assert result.behavior["type"] == "raise_exception" + assert result.output_info["reason"] == "dangerous_content" + + # Test sensitive content (should reject content) + data_sensitive = ToolOutputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + output="This is sensitive data", + ) + result = await mixed_guardrail.run(data_sensitive) + assert result.behavior["type"] == "reject_content" + assert result.behavior["message"] == "Content was filtered" + assert result.output_info["reason"] == "sensitive_content" + + # Test clean content (should allow) + data_clean = ToolOutputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + output="This is clean content", + ) + result = await mixed_guardrail.run(data_clean) + assert result.behavior["type"] == "allow" + assert result.output_info["status"] == "clean" + + if __name__ == "__main__": # Run a simple test to verify functionality async def main(): @@ -471,10 +525,7 @@ async def main(): @tool_input_guardrail def test_guard(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: - return ToolGuardrailFunctionOutput( - output_info="test_passed", - tripwire_triggered=False, - ) + return ToolGuardrailFunctionOutput.allow(output_info="test_passed") print(f"✅ Created guardrail: {test_guard.get_name()}") print("✅ All basic tests passed!") From 198c8bdbd6c9c8b9490a92e446cdc8c5a06335a2 Mon Sep 17 00:00:00 2001 From: Steven C Date: Thu, 25 Sep 2025 11:25:04 -0400 Subject: [PATCH 3/4] Extract nested tool guardrail logic in _run_impl --- examples/basic/tool_guardrails.py | 11 +- src/agents/_run_impl.py | 249 +++++++++++++++++++++--------- src/agents/run.py | 2 + 3 files changed, 186 insertions(+), 76 deletions(-) diff --git a/examples/basic/tool_guardrails.py b/examples/basic/tool_guardrails.py index 0e7af6928..6151bf3a0 100644 --- a/examples/basic/tool_guardrails.py +++ b/examples/basic/tool_guardrails.py @@ -32,6 +32,7 @@ def get_user_data(user_id: str) -> dict[str, str]: "phone": "555-1234", } + @function_tool def get_contact_info(user_id: str) -> dict[str, str]: """Get contact info by ID.""" @@ -122,7 +123,9 @@ async def main(): # Example 2: Input guardrail triggers - function tool call is rejected but execution continues print("2. Attempting to send email with suspicious content:") - result = await Runner.run(agent, "Send an email to john@example.com introducing the company ACME corp.") + result = await Runner.run( + agent, "Send an email to john@example.com introducing the company ACME corp." + ) print(f"❌ Guardrail rejected function tool call: {result.final_output}\n") except Exception as e: print(f"Error: {e}\n") @@ -136,7 +139,6 @@ async def main(): print("🚨 Output guardrail triggered: Execution halted for sensitive data") print(f"Details: {e.output.output_info}\n") - try: # Example 4: Output guardrail triggers - reject returning function tool output but continue execution print("4. Rejecting function tool output containing phone numbers:") @@ -145,6 +147,7 @@ async def main(): except Exception as e: print(f"Error: {e}\n") + if __name__ == "__main__": asyncio.run(main()) @@ -157,12 +160,12 @@ async def main(): ✅ Successful tool execution: I've sent a welcome email to john@example.com with an appropriate subject and greeting message. 2. Attempting to send email with suspicious content: -❌ Guardrail rejected function tool call: I'm unable to send the email mentioning ACME Corp as it was blocked by security guardrails. +❌ Guardrail rejected function tool call: I'm unable to send the email as mentioning ACME Corp. is restricted. 3. Attempting to get user data (contains SSN). Execution blocked: 🚨 Output guardrail triggered: Execution halted for sensitive data Details: {'blocked_pattern': 'SSN', 'tool': 'get_user_data'} 4. Rejecting function tool output containing sensitive data: -✅ Successful tool execution: User data retrieved (phone number redacted for privacy) +❌ Guardrail rejected function tool output: I'm unable to retrieve the contact info for user456 because it contains restricted information. """ diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 75c770dd2..b410b21b4 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -341,6 +341,8 @@ async def execute_tools_and_side_effects( final_output=check_tool_use.final_output, hooks=hooks, context_wrapper=context_wrapper, + tool_input_guardrail_results=tool_input_guardrail_results, + tool_output_guardrail_results=tool_output_guardrail_results, ) # Now we can check if the model also produced a final output @@ -574,6 +576,155 @@ def process_model_response( mcp_approval_requests=mcp_approval_requests, ) + @classmethod + async def _execute_input_guardrails( + cls, + *, + func_tool: FunctionTool, + tool_context: ToolContext[TContext], + agent: Agent[TContext], + tool_input_guardrail_results: list[ToolInputGuardrailResult], + ) -> str | None: + """Execute input guardrails for a tool. + + Args: + func_tool: The function tool being executed. + tool_context: The tool execution context. + agent: The agent executing the tool. + tool_input_guardrail_results: List to append guardrail results to. + + Returns: + None if tool execution should proceed, or a message string if execution should be + skipped. + + Raises: + ToolInputGuardrailTripwireTriggered: If a guardrail triggers an exception. + """ + if not func_tool.tool_input_guardrails: + return None + + for guardrail in func_tool.tool_input_guardrails: + gr_out = await guardrail.run( + ToolInputGuardrailData( + context=tool_context, + agent=agent, + ) + ) + + # Store the guardrail result + tool_input_guardrail_results.append( + ToolInputGuardrailResult( + guardrail=guardrail, + output=gr_out, + ) + ) + + # Handle different behavior types + if gr_out.behavior["type"] == "raise_exception": + raise ToolInputGuardrailTripwireTriggered(guardrail=guardrail, output=gr_out) + elif gr_out.behavior["type"] == "reject_content": + # Set final_result to the message and skip tool execution + return gr_out.behavior["message"] + elif gr_out.behavior["type"] == "allow": + # Continue to next guardrail or tool execution + continue + + return None + + @classmethod + async def _execute_output_guardrails( + cls, + *, + func_tool: FunctionTool, + tool_context: ToolContext[TContext], + agent: Agent[TContext], + real_result: Any, + tool_output_guardrail_results: list[ToolOutputGuardrailResult], + ) -> Any: + """Execute output guardrails for a tool. + + Args: + func_tool: The function tool being executed. + tool_context: The tool execution context. + agent: The agent executing the tool. + real_result: The actual result from the tool execution. + tool_output_guardrail_results: List to append guardrail results to. + + Returns: + The final result after guardrail processing (may be modified). + + Raises: + ToolOutputGuardrailTripwireTriggered: If a guardrail triggers an exception. + """ + if not func_tool.tool_output_guardrails: + return real_result + + final_result = real_result + for output_guardrail in func_tool.tool_output_guardrails: + gr_out = await output_guardrail.run( + ToolOutputGuardrailData( + context=tool_context, + agent=agent, + output=real_result, + ) + ) + + # Store the guardrail result + tool_output_guardrail_results.append( + ToolOutputGuardrailResult( + guardrail=output_guardrail, + output=gr_out, + ) + ) + + # Handle different behavior types + if gr_out.behavior["type"] == "raise_exception": + raise ToolOutputGuardrailTripwireTriggered( + guardrail=output_guardrail, output=gr_out + ) + elif gr_out.behavior["type"] == "reject_content": + # Override the result with the guardrail message + final_result = gr_out.behavior["message"] + break + elif gr_out.behavior["type"] == "allow": + # Continue to next guardrail + continue + + return final_result + + @classmethod + async def _execute_tool_with_hooks( + cls, + *, + func_tool: FunctionTool, + tool_context: ToolContext[TContext], + agent: Agent[TContext], + hooks: RunHooks[TContext], + tool_call: ResponseFunctionToolCall, + ) -> Any: + """Execute the core tool function with before/after hooks. + + Args: + func_tool: The function tool being executed. + tool_context: The tool execution context. + agent: The agent executing the tool. + hooks: The run hooks to execute. + tool_call: The tool call details. + + Returns: + The result from the tool execution. + """ + await asyncio.gather( + hooks.on_tool_start(tool_context, agent, func_tool), + ( + agent.hooks.on_tool_start(tool_context, agent, func_tool) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + return await func_tool.on_invoke_tool(tool_context, tool_call.arguments) + @classmethod async def execute_function_tool_calls( cls, @@ -603,83 +754,35 @@ async def run_single_tool( span_fn.span_data.input = tool_call.arguments try: # 1) Run input tool guardrails, if any - final_result: Any | None = None - if func_tool.tool_input_guardrails: - for guardrail in func_tool.tool_input_guardrails: - gr_out = await guardrail.run( - ToolInputGuardrailData( - context=tool_context, - agent=agent, - ) - ) - - # Store the guardrail result - tool_input_guardrail_results.append( - ToolInputGuardrailResult( - guardrail=guardrail, - output=gr_out, - ) - ) - - # Handle different behavior types - if gr_out.behavior["type"] == "raise_exception": - raise ToolInputGuardrailTripwireTriggered( - guardrail=guardrail, output=gr_out - ) - elif gr_out.behavior["type"] == "reject_content": - # Set final_result to the message and skip tool execution - final_result = gr_out.behavior["message"] - break - elif gr_out.behavior["type"] == "allow": - # Continue to next guardrail or tool execution - continue - - if final_result is None: + rejected_message = await cls._execute_input_guardrails( + func_tool=func_tool, + tool_context=tool_context, + agent=agent, + tool_input_guardrail_results=tool_input_guardrail_results, + ) + + if rejected_message is not None: + # Input guardrail rejected the tool call + final_result = rejected_message + else: # 2) Actually run the tool - await asyncio.gather( - hooks.on_tool_start(tool_context, agent, func_tool), - ( - agent.hooks.on_tool_start(tool_context, agent, func_tool) - if agent.hooks - else _coro.noop_coroutine() - ), - ) - real_result = await func_tool.on_invoke_tool( - tool_context, tool_call.arguments + real_result = await cls._execute_tool_with_hooks( + func_tool=func_tool, + tool_context=tool_context, + agent=agent, + hooks=hooks, + tool_call=tool_call, ) # 3) Run output tool guardrails, if any - final_result = real_result - if func_tool.tool_output_guardrails: - for output_guardrail in func_tool.tool_output_guardrails: - gr_out = await output_guardrail.run( - ToolOutputGuardrailData( - context=tool_context, - agent=agent, - output=real_result, - ) - ) + final_result = await cls._execute_output_guardrails( + func_tool=func_tool, + tool_context=tool_context, + agent=agent, + real_result=real_result, + tool_output_guardrail_results=tool_output_guardrail_results, + ) - # Store the guardrail result - tool_output_guardrail_results.append( - ToolOutputGuardrailResult( - guardrail=output_guardrail, - output=gr_out, - ) - ) - - # Handle different behavior types - if gr_out.behavior["type"] == "raise_exception": - raise ToolOutputGuardrailTripwireTriggered( - guardrail=output_guardrail, output=gr_out - ) - elif gr_out.behavior["type"] == "reject_content": - # Override the result with the guardrail message - final_result = gr_out.behavior["message"] - break - elif gr_out.behavior["type"] == "allow": - # Continue to next guardrail - continue # 4) Tool end hooks (with final result, which may have been overridden) await asyncio.gather( hooks.on_tool_end(tool_context, agent, func_tool, final_result), @@ -932,6 +1035,8 @@ async def execute_handoffs( pre_step_items=pre_step_items, new_step_items=new_step_items, next_step=NextStepHandoff(new_agent), + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], ) @classmethod diff --git a/src/agents/run.py b/src/agents/run.py index 0b4df3174..b8f9dfebf 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -715,6 +715,8 @@ def run_streamed( max_turns=max_turns, input_guardrail_results=[], output_guardrail_results=[], + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], _current_agent_output_schema=output_schema, trace=new_trace, context_wrapper=context_wrapper, From cdd930849fcef3ddb93e255a3e12bc0c716915f3 Mon Sep 17 00:00:00 2001 From: Steven C Date: Thu, 25 Sep 2025 11:49:58 -0400 Subject: [PATCH 4/4] make check formatting --- examples/basic/tool_guardrails.py | 2 +- src/agents/tool_guardrails.py | 13 +++++-------- .../memory/test_advanced_sqlite_session.py | 2 ++ tests/test_result_cast.py | 2 ++ 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/examples/basic/tool_guardrails.py b/examples/basic/tool_guardrails.py index 6151bf3a0..661d66b71 100644 --- a/examples/basic/tool_guardrails.py +++ b/examples/basic/tool_guardrails.py @@ -97,7 +97,7 @@ def reject_phone_numbers(data: ToolOutputGuardrailData) -> ToolGuardrailFunction message="User data not retrieved as it contains a phone number which is restricted.", output_info={"redacted": "phone_number"}, ) - return ToolGuardrailFunctionOutput() + return ToolGuardrailFunctionOutput(output_info="Phone number check passed") # Apply guardrails to tools diff --git a/src/agents/tool_guardrails.py b/src/agents/tool_guardrails.py index d02bd0dc5..545a11761 100644 --- a/src/agents/tool_guardrails.py +++ b/src/agents/tool_guardrails.py @@ -2,7 +2,7 @@ import inspect from collections.abc import Awaitable -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, overload from typing_extensions import TypedDict, TypeVar @@ -66,19 +66,16 @@ class ToolGuardrailFunctionOutput: information about the checks it performed and granular results. """ - behavior: RejectContentBehavior | RaiseExceptionBehavior | AllowBehavior | None = None + behavior: RejectContentBehavior | RaiseExceptionBehavior | AllowBehavior = field( + default_factory=lambda: AllowBehavior(type="allow") + ) """ Defines how the system should respond when this guardrail result is processed. - - None/allow: Allow normal tool execution to continue without interference (default) + - allow: Allow normal tool execution to continue without interference (default) - reject_content: Reject the tool call/output but continue execution with a message to the model - raise_exception: Halt execution by raising a ToolGuardrailTripwireTriggered exception """ - def __post_init__(self) -> None: - """Set default behavior if none specified.""" - if self.behavior is None: - self.behavior = AllowBehavior(type="allow") - @classmethod def allow(cls, output_info: Any = None) -> ToolGuardrailFunctionOutput: """Create a guardrail output that allows the tool execution to continue normally. diff --git a/tests/extensions/memory/test_advanced_sqlite_session.py b/tests/extensions/memory/test_advanced_sqlite_session.py index d352e6c40..40edb99fe 100644 --- a/tests/extensions/memory/test_advanced_sqlite_session.py +++ b/tests/extensions/memory/test_advanced_sqlite_session.py @@ -70,6 +70,8 @@ def create_mock_run_result( final_output="test output", input_guardrail_results=[], output_guardrail_results=[], + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], context_wrapper=context_wrapper, _last_agent=agent, ) diff --git a/tests/test_result_cast.py b/tests/test_result_cast.py index c621e7352..4ef1a293d 100644 --- a/tests/test_result_cast.py +++ b/tests/test_result_cast.py @@ -14,6 +14,8 @@ def create_run_result(final_output: Any) -> RunResult: final_output=final_output, input_guardrail_results=[], output_guardrail_results=[], + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], _last_agent=Agent(name="test"), context_wrapper=RunContextWrapper(context=None), )