diff --git a/docs/guardrails.md b/docs/guardrails.md index 8df904a4c..3fee17df0 100644 --- a/docs/guardrails.md +++ b/docs/guardrails.md @@ -151,4 +151,4 @@ async def main(): 1. This is the actual agent's output type. 2. This is the guardrail's output type. 3. This is the guardrail function that receives the agent's output, and returns the result. -4. This is the actual agent that defines the workflow. +4. This is the actual agent that defines the workflow. \ No newline at end of file diff --git a/examples/basic/tool_guardrails.py b/examples/basic/tool_guardrails.py new file mode 100644 index 000000000..661d66b71 --- /dev/null +++ b/examples/basic/tool_guardrails.py @@ -0,0 +1,171 @@ +import asyncio +import json + +from agents import ( + Agent, + Runner, + ToolGuardrailFunctionOutput, + ToolInputGuardrailData, + ToolOutputGuardrailData, + ToolOutputGuardrailTripwireTriggered, + function_tool, + tool_input_guardrail, + tool_output_guardrail, +) + + +@function_tool +def send_email(to: str, subject: str, body: str) -> str: + """Send an email to the specified recipient.""" + return f"Email sent to {to} with subject '{subject}'" + + +@function_tool +def get_user_data(user_id: str) -> dict[str, str]: + """Get user data by ID.""" + # Simulate returning sensitive data + return { + "user_id": user_id, + "name": "John Doe", + "email": "john@example.com", + "ssn": "123-45-6789", # Sensitive data that should be blocked! + "phone": "555-1234", + } + + +@function_tool +def get_contact_info(user_id: str) -> dict[str, str]: + """Get contact info by ID.""" + return { + "user_id": user_id, + "name": "Jane Smith", + "email": "jane@example.com", + "phone": "555-1234", + } + + +@tool_input_guardrail +def reject_sensitive_words(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + """Reject tool calls that contain sensitive words in arguments.""" + try: + args = json.loads(data.context.tool_arguments) if data.context.tool_arguments else {} + except json.JSONDecodeError: + return ToolGuardrailFunctionOutput(output_info="Invalid JSON arguments") + + # Check for suspicious content + sensitive_words = [ + "password", + "hack", + "exploit", + "malware", + "ACME", + ] + for key, value in args.items(): + value_str = str(value).lower() + for word in sensitive_words: + if word.lower() in value_str: + # Reject tool call and inform the model the function was not called + return ToolGuardrailFunctionOutput.reject_content( + message=f"🚨 Tool call blocked: contains '{word}'", + output_info={"blocked_word": word, "argument": key}, + ) + + return ToolGuardrailFunctionOutput(output_info="Input validated") + + +@tool_output_guardrail +def block_sensitive_output(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + """Block tool outputs that contain sensitive data.""" + output_str = str(data.output).lower() + + # Check for sensitive data patterns + if "ssn" in output_str or "123-45-6789" in output_str: + # Use raise_exception to halt execution completely for sensitive data + return ToolGuardrailFunctionOutput.raise_exception( + output_info={"blocked_pattern": "SSN", "tool": data.context.tool_name}, + ) + + return ToolGuardrailFunctionOutput(output_info="Output validated") + + +@tool_output_guardrail +def reject_phone_numbers(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + """Reject function output containing phone numbers.""" + output_str = str(data.output) + if "555-1234" in output_str: + return ToolGuardrailFunctionOutput.reject_content( + message="User data not retrieved as it contains a 
phone number which is restricted.", + output_info={"redacted": "phone_number"}, + ) + return ToolGuardrailFunctionOutput(output_info="Phone number check passed") + + +# Apply guardrails to tools +send_email.tool_input_guardrails = [reject_sensitive_words] +get_user_data.tool_output_guardrails = [block_sensitive_output] +get_contact_info.tool_output_guardrails = [reject_phone_numbers] + +agent = Agent( + name="Secure Assistant", + instructions="You are a helpful assistant with access to email and user data tools.", + tools=[send_email, get_user_data, get_contact_info], +) + + +async def main(): + print("=== Tool Guardrails Example ===\n") + + try: + # Example 1: Normal operation - should work fine + print("1. Normal email sending:") + result = await Runner.run(agent, "Send a welcome email to john@example.com") + print(f"✅ Successful tool execution: {result.final_output}\n") + + # Example 2: Input guardrail triggers - function tool call is rejected but execution continues + print("2. Attempting to send email with suspicious content:") + result = await Runner.run( + agent, "Send an email to john@example.com introducing the company ACME corp." + ) + print(f"❌ Guardrail rejected function tool call: {result.final_output}\n") + except Exception as e: + print(f"Error: {e}\n") + + try: + # Example 3: Output guardrail triggers - should raise exception for sensitive data + print("3. Attempting to get user data (contains SSN). Execution blocked:") + result = await Runner.run(agent, "Get the data for user ID user123") + print(f"✅ Successful tool execution: {result.final_output}\n") + except ToolOutputGuardrailTripwireTriggered as e: + print("🚨 Output guardrail triggered: Execution halted for sensitive data") + print(f"Details: {e.output.output_info}\n") + + try: + # Example 4: Output guardrail triggers - reject returning function tool output but continue execution + print("4. Rejecting function tool output containing phone numbers:") + result = await Runner.run(agent, "Get contact info for user456") + print(f"❌ Guardrail rejected function tool output: {result.final_output}\n") + except Exception as e: + print(f"Error: {e}\n") + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Example output: + +=== Tool Guardrails Example === + +1. Normal email sending: +✅ Successful tool execution: I've sent a welcome email to john@example.com with an appropriate subject and greeting message. + +2. Attempting to send email with suspicious content: +❌ Guardrail rejected function tool call: I'm unable to send the email as mentioning ACME Corp. is restricted. + +3. Attempting to get user data (contains SSN). Execution blocked: +🚨 Output guardrail triggered: Execution halted for sensitive data + Details: {'blocked_pattern': 'SSN', 'tool': 'get_user_data'} + +4. Rejecting function tool output containing sensitive data: +❌ Guardrail rejected function tool output: I'm unable to retrieve the contact info for user456 because it contains restricted information. 
+""" diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 3a8260f29..d6ff56d40 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -21,6 +21,8 @@ ModelBehaviorError, OutputGuardrailTripwireTriggered, RunErrorDetails, + ToolInputGuardrailTripwireTriggered, + ToolOutputGuardrailTripwireTriggered, UserError, ) from .guardrail import ( @@ -83,6 +85,17 @@ default_tool_error_function, function_tool, ) +from .tool_guardrails import ( + ToolGuardrailFunctionOutput, + ToolInputGuardrail, + ToolInputGuardrailData, + ToolInputGuardrailResult, + ToolOutputGuardrail, + ToolOutputGuardrailData, + ToolOutputGuardrailResult, + tool_input_guardrail, + tool_output_guardrail, +) from .tracing import ( AgentSpanData, CustomSpanData, @@ -191,6 +204,8 @@ def enable_verbose_stdout_logging(): "AgentsException", "InputGuardrailTripwireTriggered", "OutputGuardrailTripwireTriggered", + "ToolInputGuardrailTripwireTriggered", + "ToolOutputGuardrailTripwireTriggered", "DynamicPromptFunction", "GenerateDynamicPromptData", "Prompt", @@ -204,6 +219,15 @@ def enable_verbose_stdout_logging(): "GuardrailFunctionOutput", "input_guardrail", "output_guardrail", + "ToolInputGuardrail", + "ToolOutputGuardrail", + "ToolGuardrailFunctionOutput", + "ToolInputGuardrailData", + "ToolInputGuardrailResult", + "ToolOutputGuardrailData", + "ToolOutputGuardrailResult", + "tool_input_guardrail", + "tool_output_guardrail", "handoff", "Handoff", "HandoffInputData", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index a2d872bf1..b410b21b4 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -44,7 +44,13 @@ from .agent import Agent, ToolsToFinalOutputResult from .agent_output import AgentOutputSchemaBase from .computer import AsyncComputer, Computer -from .exceptions import AgentsException, ModelBehaviorError, UserError +from .exceptions import ( + AgentsException, + ModelBehaviorError, + ToolInputGuardrailTripwireTriggered, + ToolOutputGuardrailTripwireTriggered, + UserError, +) from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult from .handoffs import Handoff, HandoffInputData from .items import ( @@ -80,6 +86,12 @@ Tool, ) from .tool_context import ToolContext +from .tool_guardrails import ( + ToolInputGuardrailData, + ToolInputGuardrailResult, + ToolOutputGuardrailData, + ToolOutputGuardrailResult, +) from .tracing import ( SpanError, Trace, @@ -208,6 +220,12 @@ class SingleStepResult: next_step: NextStepHandoff | NextStepFinalOutput | NextStepRunAgain """The next step to take.""" + tool_input_guardrail_results: list[ToolInputGuardrailResult] + """Tool input guardrail results from this step.""" + + tool_output_guardrail_results: list[ToolOutputGuardrailResult] + """Tool output guardrail results from this step.""" + @property def generated_items(self) -> list[RunItem]: """Items generated during the agent run (i.e. 
everything generated after @@ -250,7 +268,10 @@ async def execute_tools_and_side_effects( new_step_items.extend(processed_response.new_items) # First, lets run the tool calls - function tools and computer actions - function_results, computer_results = await asyncio.gather( + ( + (function_results, tool_input_guardrail_results, tool_output_guardrail_results), + computer_results, + ) = await asyncio.gather( cls.execute_function_tool_calls( agent=agent, tool_runs=processed_response.functions, @@ -320,6 +341,8 @@ async def execute_tools_and_side_effects( final_output=check_tool_use.final_output, hooks=hooks, context_wrapper=context_wrapper, + tool_input_guardrail_results=tool_input_guardrail_results, + tool_output_guardrail_results=tool_output_guardrail_results, ) # Now we can check if the model also produced a final output @@ -343,6 +366,8 @@ async def execute_tools_and_side_effects( final_output=final_output, hooks=hooks, context_wrapper=context_wrapper, + tool_input_guardrail_results=tool_input_guardrail_results, + tool_output_guardrail_results=tool_output_guardrail_results, ) elif not output_schema or output_schema.is_plain_text(): return await cls.execute_final_output( @@ -354,6 +379,8 @@ async def execute_tools_and_side_effects( final_output=potential_final_output_text or "", hooks=hooks, context_wrapper=context_wrapper, + tool_input_guardrail_results=tool_input_guardrail_results, + tool_output_guardrail_results=tool_output_guardrail_results, ) # If there's no final output, we can just run again @@ -363,6 +390,8 @@ async def execute_tools_and_side_effects( pre_step_items=pre_step_items, new_step_items=new_step_items, next_step=NextStepRunAgain(), + tool_input_guardrail_results=tool_input_guardrail_results, + tool_output_guardrail_results=tool_output_guardrail_results, ) @classmethod @@ -547,6 +576,155 @@ def process_model_response( mcp_approval_requests=mcp_approval_requests, ) + @classmethod + async def _execute_input_guardrails( + cls, + *, + func_tool: FunctionTool, + tool_context: ToolContext[TContext], + agent: Agent[TContext], + tool_input_guardrail_results: list[ToolInputGuardrailResult], + ) -> str | None: + """Execute input guardrails for a tool. + + Args: + func_tool: The function tool being executed. + tool_context: The tool execution context. + agent: The agent executing the tool. + tool_input_guardrail_results: List to append guardrail results to. + + Returns: + None if tool execution should proceed, or a message string if execution should be + skipped. + + Raises: + ToolInputGuardrailTripwireTriggered: If a guardrail triggers an exception. 
+ """ + if not func_tool.tool_input_guardrails: + return None + + for guardrail in func_tool.tool_input_guardrails: + gr_out = await guardrail.run( + ToolInputGuardrailData( + context=tool_context, + agent=agent, + ) + ) + + # Store the guardrail result + tool_input_guardrail_results.append( + ToolInputGuardrailResult( + guardrail=guardrail, + output=gr_out, + ) + ) + + # Handle different behavior types + if gr_out.behavior["type"] == "raise_exception": + raise ToolInputGuardrailTripwireTriggered(guardrail=guardrail, output=gr_out) + elif gr_out.behavior["type"] == "reject_content": + # Set final_result to the message and skip tool execution + return gr_out.behavior["message"] + elif gr_out.behavior["type"] == "allow": + # Continue to next guardrail or tool execution + continue + + return None + + @classmethod + async def _execute_output_guardrails( + cls, + *, + func_tool: FunctionTool, + tool_context: ToolContext[TContext], + agent: Agent[TContext], + real_result: Any, + tool_output_guardrail_results: list[ToolOutputGuardrailResult], + ) -> Any: + """Execute output guardrails for a tool. + + Args: + func_tool: The function tool being executed. + tool_context: The tool execution context. + agent: The agent executing the tool. + real_result: The actual result from the tool execution. + tool_output_guardrail_results: List to append guardrail results to. + + Returns: + The final result after guardrail processing (may be modified). + + Raises: + ToolOutputGuardrailTripwireTriggered: If a guardrail triggers an exception. + """ + if not func_tool.tool_output_guardrails: + return real_result + + final_result = real_result + for output_guardrail in func_tool.tool_output_guardrails: + gr_out = await output_guardrail.run( + ToolOutputGuardrailData( + context=tool_context, + agent=agent, + output=real_result, + ) + ) + + # Store the guardrail result + tool_output_guardrail_results.append( + ToolOutputGuardrailResult( + guardrail=output_guardrail, + output=gr_out, + ) + ) + + # Handle different behavior types + if gr_out.behavior["type"] == "raise_exception": + raise ToolOutputGuardrailTripwireTriggered( + guardrail=output_guardrail, output=gr_out + ) + elif gr_out.behavior["type"] == "reject_content": + # Override the result with the guardrail message + final_result = gr_out.behavior["message"] + break + elif gr_out.behavior["type"] == "allow": + # Continue to next guardrail + continue + + return final_result + + @classmethod + async def _execute_tool_with_hooks( + cls, + *, + func_tool: FunctionTool, + tool_context: ToolContext[TContext], + agent: Agent[TContext], + hooks: RunHooks[TContext], + tool_call: ResponseFunctionToolCall, + ) -> Any: + """Execute the core tool function with before/after hooks. + + Args: + func_tool: The function tool being executed. + tool_context: The tool execution context. + agent: The agent executing the tool. + hooks: The run hooks to execute. + tool_call: The tool call details. + + Returns: + The result from the tool execution. 
+ """ + await asyncio.gather( + hooks.on_tool_start(tool_context, agent, func_tool), + ( + agent.hooks.on_tool_start(tool_context, agent, func_tool) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + return await func_tool.on_invoke_tool(tool_context, tool_call.arguments) + @classmethod async def execute_function_tool_calls( cls, @@ -556,7 +734,13 @@ async def execute_function_tool_calls( hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], config: RunConfig, - ) -> list[FunctionToolResult]: + ) -> tuple[ + list[FunctionToolResult], list[ToolInputGuardrailResult], list[ToolOutputGuardrailResult] + ]: + # Collect guardrail results + tool_input_guardrail_results: list[ToolInputGuardrailResult] = [] + tool_output_guardrail_results: list[ToolOutputGuardrailResult] = [] + async def run_single_tool( func_tool: FunctionTool, tool_call: ResponseFunctionToolCall ) -> Any: @@ -569,24 +753,48 @@ async def run_single_tool( if config.trace_include_sensitive_data: span_fn.span_data.input = tool_call.arguments try: - _, _, result = await asyncio.gather( - hooks.on_tool_start(tool_context, agent, func_tool), - ( - agent.hooks.on_tool_start(tool_context, agent, func_tool) - if agent.hooks - else _coro.noop_coroutine() - ), - func_tool.on_invoke_tool(tool_context, tool_call.arguments), + # 1) Run input tool guardrails, if any + rejected_message = await cls._execute_input_guardrails( + func_tool=func_tool, + tool_context=tool_context, + agent=agent, + tool_input_guardrail_results=tool_input_guardrail_results, ) - await asyncio.gather( - hooks.on_tool_end(tool_context, agent, func_tool, result), - ( - agent.hooks.on_tool_end(tool_context, agent, func_tool, result) - if agent.hooks - else _coro.noop_coroutine() - ), - ) + if rejected_message is not None: + # Input guardrail rejected the tool call + final_result = rejected_message + else: + # 2) Actually run the tool + real_result = await cls._execute_tool_with_hooks( + func_tool=func_tool, + tool_context=tool_context, + agent=agent, + hooks=hooks, + tool_call=tool_call, + ) + + # 3) Run output tool guardrails, if any + final_result = await cls._execute_output_guardrails( + func_tool=func_tool, + tool_context=tool_context, + agent=agent, + real_result=real_result, + tool_output_guardrail_results=tool_output_guardrail_results, + ) + + # 4) Tool end hooks (with final result, which may have been overridden) + await asyncio.gather( + hooks.on_tool_end(tool_context, agent, func_tool, final_result), + ( + agent.hooks.on_tool_end( + tool_context, agent, func_tool, final_result + ) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + result = final_result except Exception as e: _error_tracing.attach_error_to_current_span( SpanError( @@ -609,7 +817,7 @@ async def run_single_tool( results = await asyncio.gather(*tasks) - return [ + function_tool_results = [ FunctionToolResult( tool=tool_run.function_tool, output=result, @@ -622,6 +830,8 @@ async def run_single_tool( for tool_run, result in zip(tool_runs, results) ] + return function_tool_results, tool_input_guardrail_results, tool_output_guardrail_results + @classmethod async def execute_local_shell_calls( cls, @@ -825,6 +1035,8 @@ async def execute_handoffs( pre_step_items=pre_step_items, new_step_items=new_step_items, next_step=NextStepHandoff(new_agent), + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], ) @classmethod @@ -873,6 +1085,8 @@ async def execute_final_output( final_output: Any, hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], + 
tool_input_guardrail_results: list[ToolInputGuardrailResult], + tool_output_guardrail_results: list[ToolOutputGuardrailResult], ) -> SingleStepResult: # Run the on_end hooks await cls.run_final_output_hooks(agent, hooks, context_wrapper, final_output) @@ -883,6 +1097,8 @@ async def execute_final_output( pre_step_items=pre_step_items, new_step_items=new_step_items, next_step=NextStepFinalOutput(final_output), + tool_input_guardrail_results=tool_input_guardrail_results, + tool_output_guardrail_results=tool_output_guardrail_results, ) @classmethod diff --git a/src/agents/exceptions.py b/src/agents/exceptions.py index c00024c2e..39518c39d 100644 --- a/src/agents/exceptions.py +++ b/src/agents/exceptions.py @@ -8,6 +8,11 @@ from .guardrail import InputGuardrailResult, OutputGuardrailResult from .items import ModelResponse, RunItem, TResponseInputItem from .run_context import RunContextWrapper + from .tool_guardrails import ( + ToolGuardrailFunctionOutput, + ToolInputGuardrail, + ToolOutputGuardrail, + ) from .util._pretty_print import pretty_print_run_error_details @@ -94,3 +99,33 @@ def __init__(self, guardrail_result: OutputGuardrailResult): super().__init__( f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire" ) + + +class ToolInputGuardrailTripwireTriggered(AgentsException): + """Exception raised when a tool input guardrail tripwire is triggered.""" + + guardrail: ToolInputGuardrail[Any] + """The guardrail that was triggered.""" + + output: ToolGuardrailFunctionOutput + """The output from the guardrail function.""" + + def __init__(self, guardrail: ToolInputGuardrail[Any], output: ToolGuardrailFunctionOutput): + self.guardrail = guardrail + self.output = output + super().__init__(f"Tool input guardrail {guardrail.__class__.__name__} triggered tripwire") + + +class ToolOutputGuardrailTripwireTriggered(AgentsException): + """Exception raised when a tool output guardrail tripwire is triggered.""" + + guardrail: ToolOutputGuardrail[Any] + """The guardrail that was triggered.""" + + output: ToolGuardrailFunctionOutput + """The output from the guardrail function.""" + + def __init__(self, guardrail: ToolOutputGuardrail[Any], output: ToolGuardrailFunctionOutput): + self.guardrail = guardrail + self.output = output + super().__init__(f"Tool output guardrail {guardrail.__class__.__name__} triggered tripwire") diff --git a/src/agents/result.py b/src/agents/result.py index 26609da28..a76a98b29 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -31,6 +31,7 @@ if TYPE_CHECKING: from ._run_impl import QueueCompleteSentinel from .agent import Agent + from .tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult T = TypeVar("T") @@ -59,6 +60,12 @@ class RunResultBase(abc.ABC): output_guardrail_results: list[OutputGuardrailResult] """Guardrail results for the final output of the agent.""" + tool_input_guardrail_results: list[ToolInputGuardrailResult] + """Tool input guardrail results from all tools executed during the run.""" + + tool_output_guardrail_results: list[ToolOutputGuardrailResult] + """Tool output guardrail results from all tools executed during the run.""" + context_wrapper: RunContextWrapper[Any] """The context wrapper for the agent run.""" diff --git a/src/agents/run.py b/src/agents/run.py index 722aa1558..b8f9dfebf 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -68,6 +68,7 @@ StreamEvent, ) from .tool import Tool +from .tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult from .tracing import Span, 
SpanError, agent_span, get_current_trace, trace from .tracing.span_data import AgentSpanData from .usage import Usage @@ -494,6 +495,8 @@ async def run( ) input_guardrail_results: list[InputGuardrailResult] = [] + tool_input_guardrail_results: list[ToolInputGuardrailResult] = [] + tool_output_guardrail_results: list[ToolOutputGuardrailResult] = [] current_span: Span[AgentSpanData] | None = None current_agent = starting_agent @@ -584,6 +587,10 @@ async def run( original_input = turn_result.original_input generated_items = turn_result.generated_items + # Collect tool guardrail results from this turn + tool_input_guardrail_results.extend(turn_result.tool_input_guardrail_results) + tool_output_guardrail_results.extend(turn_result.tool_output_guardrail_results) + if isinstance(turn_result.next_step, NextStepFinalOutput): output_guardrail_results = await self._run_output_guardrails( current_agent.output_guardrails + (run_config.output_guardrails or []), @@ -599,6 +606,8 @@ async def run( _last_agent=current_agent, input_guardrail_results=input_guardrail_results, output_guardrail_results=output_guardrail_results, + tool_input_guardrail_results=tool_input_guardrail_results, + tool_output_guardrail_results=tool_output_guardrail_results, context_wrapper=context_wrapper, ) await self._save_result_to_session(session, [], turn_result.new_step_items) @@ -706,6 +715,8 @@ def run_streamed( max_turns=max_turns, input_guardrail_results=[], output_guardrail_results=[], + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], _current_agent_output_schema=output_schema, trace=new_trace, context_wrapper=context_wrapper, diff --git a/src/agents/tool.py b/src/agents/tool.py index 7ba9435ed..2796754d5 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -27,6 +27,7 @@ from .run_context import RunContextWrapper from .strict_schema import ensure_strict_json_schema from .tool_context import ToolContext +from .tool_guardrails import ToolInputGuardrail, ToolOutputGuardrail from .tracing import SpanError from .util import _error_tracing from .util._types import MaybeAwaitable @@ -94,6 +95,13 @@ class FunctionTool: and returns whether the tool is enabled. 
You can use this to dynamically enable/disable a tool based on your context/state.""" + # Tool-specific guardrails + tool_input_guardrails: list[ToolInputGuardrail[Any]] | None = None + """Optional list of input guardrails to run before invoking this tool.""" + + tool_output_guardrails: list[ToolOutputGuardrail[Any]] | None = None + """Optional list of output guardrails to run after invoking this tool.""" + def __post_init__(self): if self.strict_json_schema: self.params_json_schema = ensure_strict_json_schema(self.params_json_schema) diff --git a/src/agents/tool_guardrails.py b/src/agents/tool_guardrails.py new file mode 100644 index 000000000..545a11761 --- /dev/null +++ b/src/agents/tool_guardrails.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +import inspect +from collections.abc import Awaitable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, overload + +from typing_extensions import TypedDict, TypeVar + +from .exceptions import UserError +from .tool_context import ToolContext +from .util._types import MaybeAwaitable + +if TYPE_CHECKING: + from .agent import Agent + + +@dataclass +class ToolInputGuardrailResult: + """The result of a tool input guardrail run.""" + + guardrail: ToolInputGuardrail[Any] + """The guardrail that was run.""" + + output: ToolGuardrailFunctionOutput + """The output of the guardrail function.""" + + +@dataclass +class ToolOutputGuardrailResult: + """The result of a tool output guardrail run.""" + + guardrail: ToolOutputGuardrail[Any] + """The guardrail that was run.""" + + output: ToolGuardrailFunctionOutput + """The output of the guardrail function.""" + + +class RejectContentBehavior(TypedDict): + """Rejects the tool call/output but continues execution with a message to the model.""" + + type: Literal["reject_content"] + message: str + + +class RaiseExceptionBehavior(TypedDict): + """Raises an exception to halt execution.""" + + type: Literal["raise_exception"] + + +class AllowBehavior(TypedDict): + """Allows normal tool execution to continue.""" + + type: Literal["allow"] + + +@dataclass +class ToolGuardrailFunctionOutput: + """The output of a tool guardrail function.""" + + output_info: Any + """ + Optional data about checks performed. For example, the guardrail could include + information about the checks it performed and granular results. + """ + + behavior: RejectContentBehavior | RaiseExceptionBehavior | AllowBehavior = field( + default_factory=lambda: AllowBehavior(type="allow") + ) + """ + Defines how the system should respond when this guardrail result is processed. + - allow: Allow normal tool execution to continue without interference (default) + - reject_content: Reject the tool call/output but continue execution with a message to the model + - raise_exception: Halt execution by raising a ToolGuardrailTripwireTriggered exception + """ + + @classmethod + def allow(cls, output_info: Any = None) -> ToolGuardrailFunctionOutput: + """Create a guardrail output that allows the tool execution to continue normally. + + Args: + output_info: Optional data about checks performed. + + Returns: + ToolGuardrailFunctionOutput configured to allow normal execution. + """ + return cls(output_info=output_info, behavior=AllowBehavior(type="allow")) + + @classmethod + def reject_content(cls, message: str, output_info: Any = None) -> ToolGuardrailFunctionOutput: + """Create a guardrail output that rejects the tool call/output but continues execution. 
+ + Args: + message: Message to send to the model instead of the tool result. + output_info: Optional data about checks performed. + + Returns: + ToolGuardrailFunctionOutput configured to reject the content. + """ + return cls( + output_info=output_info, + behavior=RejectContentBehavior(type="reject_content", message=message), + ) + + @classmethod + def raise_exception(cls, output_info: Any = None) -> ToolGuardrailFunctionOutput: + """Create a guardrail output that raises an exception to halt execution. + + Args: + output_info: Optional data about checks performed. + + Returns: + ToolGuardrailFunctionOutput configured to raise an exception. + """ + return cls(output_info=output_info, behavior=RaiseExceptionBehavior(type="raise_exception")) + + +@dataclass +class ToolInputGuardrailData: + """Input data passed to a tool input guardrail function.""" + + context: ToolContext[Any] + """ + The tool context containing information about the current tool execution. + """ + + agent: Agent[Any] + """ + The agent that is executing the tool. + """ + + +@dataclass +class ToolOutputGuardrailData(ToolInputGuardrailData): + """Input data passed to a tool output guardrail function. + + Extends input data with the tool's output. + """ + + output: Any + """ + The output produced by the tool function. + """ + + +TContext_co = TypeVar("TContext_co", bound=Any, covariant=True) + + +@dataclass +class ToolInputGuardrail(Generic[TContext_co]): + """A guardrail that runs before a function tool is invoked.""" + + guardrail_function: Callable[ + [ToolInputGuardrailData], MaybeAwaitable[ToolGuardrailFunctionOutput] + ] + """ + The function that implements the guardrail logic. + """ + + name: str | None = None + """ + Optional name for the guardrail. If not provided, uses the function name. + """ + + def get_name(self) -> str: + return self.name or self.guardrail_function.__name__ + + async def run(self, data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + if not callable(self.guardrail_function): + raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}") + + result = self.guardrail_function(data) + if inspect.isawaitable(result): + return await result + return result + + +@dataclass +class ToolOutputGuardrail(Generic[TContext_co]): + """A guardrail that runs after a function tool is invoked.""" + + guardrail_function: Callable[ + [ToolOutputGuardrailData], MaybeAwaitable[ToolGuardrailFunctionOutput] + ] + """ + The function that implements the guardrail logic. + """ + + name: str | None = None + """ + Optional name for the guardrail. If not provided, uses the function name. + """ + + def get_name(self) -> str: + return self.name or self.guardrail_function.__name__ + + async def run(self, data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + if not callable(self.guardrail_function): + raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}") + + result = self.guardrail_function(data) + if inspect.isawaitable(result): + return await result + return result + + +# Decorators +_ToolInputFuncSync = Callable[[ToolInputGuardrailData], ToolGuardrailFunctionOutput] +_ToolInputFuncAsync = Callable[[ToolInputGuardrailData], Awaitable[ToolGuardrailFunctionOutput]] + + +@overload +def tool_input_guardrail(func: _ToolInputFuncSync): ... + + +@overload +def tool_input_guardrail(func: _ToolInputFuncAsync): ... 
+ + +@overload +def tool_input_guardrail( + *, name: str | None = None +) -> Callable[[_ToolInputFuncSync | _ToolInputFuncAsync], ToolInputGuardrail[Any]]: ... + + +def tool_input_guardrail( + func: _ToolInputFuncSync | _ToolInputFuncAsync | None = None, + *, + name: str | None = None, +) -> ( + ToolInputGuardrail[Any] + | Callable[[_ToolInputFuncSync | _ToolInputFuncAsync], ToolInputGuardrail[Any]] +): + """Decorator to create a ToolInputGuardrail from a function.""" + + def decorator(f: _ToolInputFuncSync | _ToolInputFuncAsync) -> ToolInputGuardrail[Any]: + return ToolInputGuardrail(guardrail_function=f, name=name or f.__name__) + + if func is not None: + return decorator(func) + return decorator + + +_ToolOutputFuncSync = Callable[[ToolOutputGuardrailData], ToolGuardrailFunctionOutput] +_ToolOutputFuncAsync = Callable[[ToolOutputGuardrailData], Awaitable[ToolGuardrailFunctionOutput]] + + +@overload +def tool_output_guardrail(func: _ToolOutputFuncSync): ... + + +@overload +def tool_output_guardrail(func: _ToolOutputFuncAsync): ... + + +@overload +def tool_output_guardrail( + *, name: str | None = None +) -> Callable[[_ToolOutputFuncSync | _ToolOutputFuncAsync], ToolOutputGuardrail[Any]]: ... + + +def tool_output_guardrail( + func: _ToolOutputFuncSync | _ToolOutputFuncAsync | None = None, + *, + name: str | None = None, +) -> ( + ToolOutputGuardrail[Any] + | Callable[[_ToolOutputFuncSync | _ToolOutputFuncAsync], ToolOutputGuardrail[Any]] +): + """Decorator to create a ToolOutputGuardrail from a function.""" + + def decorator(f: _ToolOutputFuncSync | _ToolOutputFuncAsync) -> ToolOutputGuardrail[Any]: + return ToolOutputGuardrail(guardrail_function=f, name=name or f.__name__) + + if func is not None: + return decorator(func) + return decorator diff --git a/tests/extensions/memory/test_advanced_sqlite_session.py b/tests/extensions/memory/test_advanced_sqlite_session.py index d352e6c40..40edb99fe 100644 --- a/tests/extensions/memory/test_advanced_sqlite_session.py +++ b/tests/extensions/memory/test_advanced_sqlite_session.py @@ -70,6 +70,8 @@ def create_mock_run_result( final_output="test output", input_guardrail_results=[], output_guardrail_results=[], + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], context_wrapper=context_wrapper, _last_agent=agent, ) diff --git a/tests/test_result_cast.py b/tests/test_result_cast.py index c621e7352..4ef1a293d 100644 --- a/tests/test_result_cast.py +++ b/tests/test_result_cast.py @@ -14,6 +14,8 @@ def create_run_result(final_output: Any) -> RunResult: final_output=final_output, input_guardrail_results=[], output_guardrail_results=[], + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], _last_agent=Agent(name="test"), context_wrapper=RunContextWrapper(context=None), ) diff --git a/tests/test_tool_guardrails.py b/tests/test_tool_guardrails.py new file mode 100644 index 000000000..8ccaec0ad --- /dev/null +++ b/tests/test_tool_guardrails.py @@ -0,0 +1,533 @@ +from __future__ import annotations + +import asyncio +from typing import Any + +import pytest + +from agents import ( + Agent, + ToolGuardrailFunctionOutput, + ToolInputGuardrail, + ToolInputGuardrailData, + ToolInputGuardrailTripwireTriggered, + ToolOutputGuardrail, + ToolOutputGuardrailData, + ToolOutputGuardrailTripwireTriggered, + UserError, +) +from agents.tool_context import ToolContext +from agents.tool_guardrails import tool_input_guardrail, tool_output_guardrail + + +def get_mock_tool_context(tool_arguments: str = '{"param": "value"}') -> ToolContext: + 
"""Helper to create a mock tool context for testing.""" + return ToolContext( + context=None, + tool_name="test_tool", + tool_call_id="call_123", + tool_arguments=tool_arguments, + ) + + +def get_sync_input_guardrail(triggers: bool, output_info: Any | None = None): + """Helper to create a sync input guardrail function.""" + + def sync_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + if triggers: + return ToolGuardrailFunctionOutput.raise_exception(output_info=output_info) + else: + return ToolGuardrailFunctionOutput.allow(output_info=output_info) + + return sync_guardrail + + +def get_async_input_guardrail(triggers: bool, output_info: Any | None = None): + """Helper to create an async input guardrail function.""" + + async def async_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + if triggers: + return ToolGuardrailFunctionOutput.raise_exception(output_info=output_info) + else: + return ToolGuardrailFunctionOutput.allow(output_info=output_info) + + return async_guardrail + + +def get_sync_output_guardrail(triggers: bool, output_info: Any | None = None): + """Helper to create a sync output guardrail function.""" + + def sync_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + if triggers: + return ToolGuardrailFunctionOutput.raise_exception(output_info=output_info) + else: + return ToolGuardrailFunctionOutput.allow(output_info=output_info) + + return sync_guardrail + + +def get_async_output_guardrail(triggers: bool, output_info: Any | None = None): + """Helper to create an async output guardrail function.""" + + async def async_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + if triggers: + return ToolGuardrailFunctionOutput.raise_exception(output_info=output_info) + else: + return ToolGuardrailFunctionOutput.allow(output_info=output_info) + + return async_guardrail + + +@pytest.mark.asyncio +async def test_sync_tool_input_guardrail(): + """Test sync tool input guardrail execution.""" + # Test non-triggering guardrail + guardrail: ToolInputGuardrail[Any] = ToolInputGuardrail( + guardrail_function=get_sync_input_guardrail(triggers=False) + ) + data = ToolInputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + ) + result = await guardrail.run(data) + assert result.behavior["type"] == "allow" + assert result.output_info is None + + # Test triggering guardrail + guardrail_2: ToolInputGuardrail[Any] = ToolInputGuardrail( + guardrail_function=get_sync_input_guardrail(triggers=True) + ) + result = await guardrail_2.run(data) + assert result.behavior["type"] == "raise_exception" + assert result.output_info is None + + # Test triggering guardrail with output info + guardrail_3: ToolInputGuardrail[Any] = ToolInputGuardrail( + guardrail_function=get_sync_input_guardrail(triggers=True, output_info="test_info") + ) + result = await guardrail_3.run(data) + assert result.behavior["type"] == "raise_exception" + assert result.output_info == "test_info" + + +@pytest.mark.asyncio +async def test_async_tool_input_guardrail(): + """Test async tool input guardrail execution.""" + # Test non-triggering guardrail + guardrail: ToolInputGuardrail[Any] = ToolInputGuardrail( + guardrail_function=get_async_input_guardrail(triggers=False) + ) + data = ToolInputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + ) + result = await guardrail.run(data) + assert result.behavior["type"] == "allow" + assert result.output_info is None + + # Test triggering guardrail + 
guardrail_2: ToolInputGuardrail[Any] = ToolInputGuardrail( + guardrail_function=get_async_input_guardrail(triggers=True) + ) + result = await guardrail_2.run(data) + assert result.behavior["type"] == "raise_exception" + assert result.output_info is None + + # Test triggering guardrail with output info + guardrail_3: ToolInputGuardrail[Any] = ToolInputGuardrail( + guardrail_function=get_async_input_guardrail(triggers=True, output_info="test_info") + ) + result = await guardrail_3.run(data) + assert result.behavior["type"] == "raise_exception" + assert result.output_info == "test_info" + + +@pytest.mark.asyncio +async def test_sync_tool_output_guardrail(): + """Test sync tool output guardrail execution.""" + # Test non-triggering guardrail + guardrail: ToolOutputGuardrail[Any] = ToolOutputGuardrail( + guardrail_function=get_sync_output_guardrail(triggers=False) + ) + data = ToolOutputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + output="test output", + ) + result = await guardrail.run(data) + assert result.behavior["type"] == "allow" + assert result.output_info is None + + # Test triggering guardrail + guardrail_2: ToolOutputGuardrail[Any] = ToolOutputGuardrail( + guardrail_function=get_sync_output_guardrail(triggers=True) + ) + result = await guardrail_2.run(data) + assert result.behavior["type"] == "raise_exception" + assert result.output_info is None + + # Test triggering guardrail with output info + guardrail_3: ToolOutputGuardrail[Any] = ToolOutputGuardrail( + guardrail_function=get_sync_output_guardrail(triggers=True, output_info="test_info") + ) + result = await guardrail_3.run(data) + assert result.behavior["type"] == "raise_exception" + assert result.output_info == "test_info" + + +@pytest.mark.asyncio +async def test_async_tool_output_guardrail(): + """Test async tool output guardrail execution.""" + # Test non-triggering guardrail + guardrail: ToolOutputGuardrail[Any] = ToolOutputGuardrail( + guardrail_function=get_async_output_guardrail(triggers=False) + ) + data = ToolOutputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + output="test output", + ) + result = await guardrail.run(data) + assert result.behavior["type"] == "allow" + assert result.output_info is None + + # Test triggering guardrail + guardrail_2: ToolOutputGuardrail[Any] = ToolOutputGuardrail( + guardrail_function=get_async_output_guardrail(triggers=True) + ) + result = await guardrail_2.run(data) + assert result.behavior["type"] == "raise_exception" + assert result.output_info is None + + # Test triggering guardrail with output info + guardrail_3: ToolOutputGuardrail[Any] = ToolOutputGuardrail( + guardrail_function=get_async_output_guardrail(triggers=True, output_info="test_info") + ) + result = await guardrail_3.run(data) + assert result.behavior["type"] == "raise_exception" + assert result.output_info == "test_info" + + +@pytest.mark.asyncio +async def test_invalid_tool_input_guardrail_raises_user_error(): + """Test that invalid guardrail functions raise UserError.""" + with pytest.raises(UserError): + # Purposely ignoring type error + guardrail: ToolInputGuardrail[Any] = ToolInputGuardrail(guardrail_function="foo") # type: ignore + data = ToolInputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + ) + await guardrail.run(data) + + +@pytest.mark.asyncio +async def test_invalid_tool_output_guardrail_raises_user_error(): + """Test that invalid guardrail functions raise UserError.""" + with pytest.raises(UserError): + # 
Purposely ignoring type error + guardrail: ToolOutputGuardrail[Any] = ToolOutputGuardrail(guardrail_function="foo") # type: ignore + data = ToolOutputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + output="test output", + ) + await guardrail.run(data) + + +# Test decorators + + +@tool_input_guardrail +def decorated_input_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput.allow(output_info="test_1") + + +@tool_input_guardrail(name="Custom input name") +def decorated_named_input_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput.allow(output_info="test_2") + + +@pytest.mark.asyncio +async def test_tool_input_guardrail_decorators(): + """Test input guardrail decorators.""" + data = ToolInputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + ) + + # Test basic decorator + guardrail = decorated_input_guardrail + result = await guardrail.run(data) + assert result.behavior["type"] == "allow" + assert result.output_info == "test_1" + + # Test named decorator + guardrail = decorated_named_input_guardrail + result = await guardrail.run(data) + assert result.behavior["type"] == "allow" + assert result.output_info == "test_2" + assert guardrail.get_name() == "Custom input name" + + +@tool_output_guardrail +def decorated_output_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput.allow(output_info="test_3") + + +@tool_output_guardrail(name="Custom output name") +def decorated_named_output_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput.allow(output_info="test_4") + + +@pytest.mark.asyncio +async def test_tool_output_guardrail_decorators(): + """Test output guardrail decorators.""" + data = ToolOutputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + output="test output", + ) + + # Test basic decorator + guardrail = decorated_output_guardrail + result = await guardrail.run(data) + assert result.behavior["type"] == "allow" + assert result.output_info == "test_3" + + # Test named decorator + guardrail = decorated_named_output_guardrail + result = await guardrail.run(data) + assert result.behavior["type"] == "allow" + assert result.output_info == "test_4" + assert guardrail.get_name() == "Custom output name" + + +# Test practical examples + + +@pytest.mark.asyncio +async def test_password_blocking_input_guardrail(): + """Test a realistic input guardrail that blocks passwords.""" + + @tool_input_guardrail + def check_for_password(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + if "password" in data.context.tool_arguments.lower(): + return ToolGuardrailFunctionOutput.reject_content( + message="Tool call blocked: contains password", + output_info={"blocked_word": "password"}, + ) + return ToolGuardrailFunctionOutput(output_info="safe_input") + + # Test with password - should trigger + data = ToolInputGuardrailData( + context=get_mock_tool_context('{"message": "Hello password world"}'), + agent=Agent(name="test"), + ) + result = await check_for_password.run(data) + assert result.behavior["type"] == "reject_content" + assert result.behavior["message"] == "Tool call blocked: contains password" + assert result.output_info["blocked_word"] == "password" + + # Test without password - should pass + data = ToolInputGuardrailData( + context=get_mock_tool_context('{"message": "Hello 
safe world"}'), + agent=Agent(name="test"), + ) + result = await check_for_password.run(data) + assert result.behavior["type"] == "allow" + assert result.output_info == "safe_input" + + +@pytest.mark.asyncio +async def test_ssn_blocking_output_guardrail(): + """Test a realistic output guardrail that blocks SSNs.""" + + @tool_output_guardrail + def check_for_ssn(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + output_str = str(data.output).lower() + if "ssn" in output_str or "123-45-6789" in output_str: + return ToolGuardrailFunctionOutput.raise_exception( + output_info={"blocked_pattern": "SSN"} + ) + return ToolGuardrailFunctionOutput(output_info="safe_output") + + # Test with SSN in output - should trigger + data = ToolOutputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + output="User SSN is 123-45-6789", + ) + result = await check_for_ssn.run(data) + assert result.behavior["type"] == "raise_exception" + assert result.output_info["blocked_pattern"] == "SSN" + + # Test with safe output - should pass + data = ToolOutputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + output="User name is John Doe", + ) + result = await check_for_ssn.run(data) + assert result.behavior["type"] == "allow" + assert result.output_info == "safe_output" + + +def test_tool_input_guardrail_exception(): + """Test the tool input guardrail tripwire exception.""" + + @tool_input_guardrail + def test_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput.raise_exception(output_info="test") + + output = ToolGuardrailFunctionOutput.raise_exception(output_info="test") + + exception = ToolInputGuardrailTripwireTriggered( + guardrail=test_guardrail, + output=output, + ) + + assert exception.guardrail == test_guardrail + assert exception.output == output + assert "ToolInputGuardrail" in str(exception) + + +def test_tool_output_guardrail_exception(): + """Test the tool output guardrail tripwire exception.""" + + @tool_output_guardrail + def test_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput.raise_exception(output_info="test") + + output = ToolGuardrailFunctionOutput.raise_exception(output_info="test") + + exception = ToolOutputGuardrailTripwireTriggered( + guardrail=test_guardrail, + output=output, + ) + + assert exception.guardrail == test_guardrail + assert exception.output == output + assert "ToolOutputGuardrail" in str(exception) + + +# Test new behavior system + + +@pytest.mark.asyncio +async def test_allow_behavior(): + """Test the allow behavior type.""" + + @tool_input_guardrail + def allow_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput.allow(output_info="allowed") + + data = ToolInputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + ) + result = await allow_guardrail.run(data) + assert result.behavior["type"] == "allow" + assert result.output_info == "allowed" + + +@pytest.mark.asyncio +async def test_reject_content_behavior(): + """Test the reject_content behavior type.""" + + @tool_input_guardrail + def reject_content_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput.reject_content( + message="Tool blocked by guardrail", output_info="rejected" + ) + + data = ToolInputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + ) + result = await 
reject_content_guardrail.run(data) + assert result.behavior["type"] == "reject_content" + assert result.behavior["message"] == "Tool blocked by guardrail" + assert result.output_info == "rejected" + + +@pytest.mark.asyncio +async def test_raise_exception_behavior(): + """Test the raise_exception behavior type.""" + + @tool_input_guardrail + def raise_exception_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput.raise_exception(output_info="exception") + + data = ToolInputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + ) + result = await raise_exception_guardrail.run(data) + assert result.behavior["type"] == "raise_exception" + assert result.output_info == "exception" + + +@pytest.mark.asyncio +async def test_mixed_behavior_output_guardrail(): + """Test mixing different behavior types in output guardrails.""" + + @tool_output_guardrail + def mixed_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + output_str = str(data.output).lower() + if "dangerous" in output_str: + return ToolGuardrailFunctionOutput.raise_exception( + output_info={"reason": "dangerous_content"} + ) + elif "sensitive" in output_str: + return ToolGuardrailFunctionOutput.reject_content( + message="Content was filtered", output_info={"reason": "sensitive_content"} + ) + else: + return ToolGuardrailFunctionOutput(output_info={"status": "clean"}) + + # Test dangerous content (should raise exception) + data_dangerous = ToolOutputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + output="This is dangerous content", + ) + result = await mixed_guardrail.run(data_dangerous) + assert result.behavior["type"] == "raise_exception" + assert result.output_info["reason"] == "dangerous_content" + + # Test sensitive content (should reject content) + data_sensitive = ToolOutputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + output="This is sensitive data", + ) + result = await mixed_guardrail.run(data_sensitive) + assert result.behavior["type"] == "reject_content" + assert result.behavior["message"] == "Content was filtered" + assert result.output_info["reason"] == "sensitive_content" + + # Test clean content (should allow) + data_clean = ToolOutputGuardrailData( + context=get_mock_tool_context(), + agent=Agent(name="test"), + output="This is clean content", + ) + result = await mixed_guardrail.run(data_clean) + assert result.behavior["type"] == "allow" + assert result.output_info["status"] == "clean" + + +if __name__ == "__main__": + # Run a simple test to verify functionality + async def main(): + print("Testing tool guardrails...") + + @tool_input_guardrail + def test_guard(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + return ToolGuardrailFunctionOutput.allow(output_info="test_passed") + + print(f"✅ Created guardrail: {test_guard.get_name()}") + print("✅ All basic tests passed!") + + asyncio.run(main())
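
For quick reference, a minimal usage sketch that assumes only the API introduced by this patch (tool_input_guardrail, ToolGuardrailFunctionOutput, the FunctionTool.tool_input_guardrails field, and the new tool_input_guardrail_results on the run result); the tool, agent, and argument names below are illustrative, not part of the patch.

import asyncio
import json

from agents import (
    Agent,
    Runner,
    ToolGuardrailFunctionOutput,
    ToolInputGuardrailData,
    function_tool,
    tool_input_guardrail,
)


@function_tool
def lookup_order(order_id: str) -> str:
    """Illustrative tool: look up an order by ID."""
    return f"Order {order_id}: shipped"


@tool_input_guardrail
def require_order_prefix(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput:
    # Tool arguments arrive as a JSON string on the tool context.
    args = json.loads(data.context.tool_arguments or "{}")
    if not str(args.get("order_id", "")).startswith("ORD-"):
        # reject_content: the tool is not invoked; the message is returned to the model.
        return ToolGuardrailFunctionOutput.reject_content(
            message="Order IDs must start with 'ORD-'.",
            output_info={"rejected_value": args.get("order_id")},
        )
    return ToolGuardrailFunctionOutput.allow(output_info="order id validated")


# Attach the guardrail directly to the FunctionTool instance.
lookup_order.tool_input_guardrails = [require_order_prefix]

agent = Agent(
    name="Order Assistant",
    instructions="Help users check their orders.",
    tools=[lookup_order],
)


async def main() -> None:
    result = await Runner.run(agent, "What's the status of order ORD-1234?")
    print(result.final_output)
    # Per-tool guardrail results are also collected on the run result.
    print(result.tool_input_guardrail_results)


if __name__ == "__main__":
    asyncio.run(main())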