Skip to content

Commit 03962fa

Browse files
committed
Adding tool input and output guardrails
1 parent a4c125e commit 03962fa

File tree

11 files changed

+1049
-18
lines changed

11 files changed

+1049
-18
lines changed

docs/guardrails.md

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,83 @@ async def main():
152152
2. This is the guardrail's output type.
153153
3. This is the guardrail function that receives the agent's output, and returns the result.
154154
4. This is the actual agent that defines the workflow.
155+
156+
## Tool guardrails
157+
158+
Tool guardrails provide fine-grained control over individual tool calls, allowing you to validate inputs and outputs at the tool level. This is particularly useful for:
159+
160+
- Blocking sensitive data in tool arguments
161+
- Preventing unauthorized access to certain tools
162+
- Sanitizing tool outputs before they're returned
163+
- Implementing custom validation logic for specific tools
164+
165+
There are two types of tool guardrails:
166+
167+
1. **Tool input guardrails** run before a tool is executed, validating the tool call arguments
168+
2. **Tool output guardrails** run after a tool is executed, validating the tool's output
169+
170+
### Tool input guardrails
171+
172+
Tool input guardrails run in 3 steps:
173+
174+
1. First, the guardrail receives the tool call data including arguments, context, and agent information
175+
2. Next, the guardrail function runs to produce a [`ToolGuardrailFunctionOutput`][agents.tool_guardrails.ToolGuardrailFunctionOutput]
176+
3. Finally, we check if [`.tripwire_triggered`][agents.tool_guardrails.ToolGuardrailFunctionOutput.tripwire_triggered] is true. If true, a [`ToolInputGuardrailTripwireTriggered`][agents.exceptions.ToolInputGuardrailTripwireTriggered] exception is raised
177+
178+
### Tool output guardrails
179+
180+
Tool output guardrails run in 3 steps:
181+
182+
1. First, the guardrail receives the tool call data plus the tool's output
183+
2. Next, the guardrail function runs to produce a [`ToolGuardrailFunctionOutput`][agents.tool_guardrails.ToolGuardrailFunctionOutput]
184+
3. Finally, we check if [`.tripwire_triggered`][agents.tool_guardrails.ToolGuardrailFunctionOutput.tripwire_triggered] is true. If true, a [`ToolOutputGuardrailTripwireTriggered`][agents.exceptions.ToolOutputGuardrailTripwireTriggered] exception is raised
185+
186+
### Implementing tool guardrails
187+
188+
You can create tool guardrails using the `@tool_input_guardrail` and `@tool_output_guardrail` decorators:
189+
190+
```python
191+
from agents import (
192+
ToolGuardrailFunctionOutput,
193+
ToolInputGuardrailData,
194+
ToolOutputGuardrailData,
195+
tool_input_guardrail,
196+
tool_output_guardrail,
197+
)
198+
199+
@tool_input_guardrail
200+
def block_sensitive_words(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput:
201+
"""Block tool calls that contain sensitive words in arguments."""
202+
# Check arguments for sensitive content
203+
if "password" in data.tool_call.arguments.lower():
204+
return ToolGuardrailFunctionOutput(
205+
tripwire_triggered=True,
206+
model_message="🚨 Tool call blocked: contains sensitive word",
207+
output_info={"blocked_word": "password"},
208+
)
209+
return ToolGuardrailFunctionOutput(tripwire_triggered=False, output_info="Input validated")
210+
211+
@tool_output_guardrail
212+
def block_sensitive_output(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput:
213+
"""Block tool outputs that contain sensitive data."""
214+
if "ssn" in str(data.output).lower():
215+
return ToolGuardrailFunctionOutput(
216+
tripwire_triggered=True,
217+
model_message="🚨 Tool output blocked: contains sensitive data",
218+
output_info={"blocked_pattern": "SSN"},
219+
)
220+
return ToolGuardrailFunctionOutput(tripwire_triggered=False, output_info="Output validated")
221+
222+
# Apply guardrails to tools
223+
my_tool.tool_input_guardrails = [block_sensitive_words]
224+
my_tool.tool_output_guardrails = [block_sensitive_output]
225+
```
226+
227+
For a complete working example, see [tool_guardrails.py](https://github.com/openai/openai-agents-python/blob/main/examples/basic/tool_guardrails.py).
228+
229+
### Key differences from agent guardrails
230+
231+
- **Scope**: Tool guardrails operate on individual tool calls, while agent guardrails operate on the entire agent input/output
232+
- **Timing**: Tool guardrails run immediately before/after tool execution, while agent guardrails run at the beginning/end of agent execution
233+
- **Data access**: Tool guardrails have access to the specific tool call arguments and outputs, plus the tool context
234+
- **Application**: Tool guardrails are applied directly to function tools via the `tool_input_guardrails` and `tool_output_guardrails` attributes

docs/ref/tool_guardrails.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# `Tool Guardrails`
2+
3+
::: agents.tool_guardrails

examples/basic/tool_guardrails.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import asyncio
2+
import json
3+
4+
from agents import (
5+
Agent,
6+
Runner,
7+
ToolGuardrailFunctionOutput,
8+
ToolInputGuardrailData,
9+
ToolInputGuardrailTripwireTriggered,
10+
ToolOutputGuardrailData,
11+
ToolOutputGuardrailTripwireTriggered,
12+
function_tool,
13+
tool_input_guardrail,
14+
tool_output_guardrail,
15+
)
16+
17+
18+
@function_tool
19+
def send_email(to: str, subject: str, body: str) -> str:
20+
"""Send an email to the specified recipient."""
21+
return f"Email sent to {to} with subject '{subject}'"
22+
23+
24+
@function_tool
25+
def get_user_data(user_id: str) -> dict[str, str]:
26+
"""Get user data by ID."""
27+
# Simulate returning sensitive data
28+
return {
29+
"user_id": user_id,
30+
"name": "John Doe",
31+
"email": "[email protected]",
32+
"ssn": "123-45-6789", # Sensitive data that should be blocked!
33+
"phone": "555-1234",
34+
}
35+
36+
37+
@tool_input_guardrail
38+
def block_sensitive_words(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput:
39+
"""Block tool calls that contain sensitive words in arguments."""
40+
try:
41+
args = json.loads(data.tool_call.arguments)
42+
except json.JSONDecodeError:
43+
return ToolGuardrailFunctionOutput(
44+
tripwire_triggered=False, output_info="Invalid JSON arguments"
45+
)
46+
47+
# Check for suspicious content
48+
sensitive_words = [
49+
"password",
50+
"hack",
51+
"exploit",
52+
"malware",
53+
"orange",
54+
] # to mock sensitive words
55+
for key, value in args.items():
56+
value_str = str(value).lower()
57+
for word in sensitive_words:
58+
if word in value_str:
59+
return ToolGuardrailFunctionOutput(
60+
tripwire_triggered=True,
61+
model_message=f"🚨 Tool call blocked: contains '{word}'",
62+
output_info={"blocked_word": word, "argument": key},
63+
)
64+
65+
return ToolGuardrailFunctionOutput(tripwire_triggered=False, output_info="Input validated")
66+
67+
68+
@tool_output_guardrail
69+
def block_sensitive_output(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput:
70+
"""Block tool outputs that contain sensitive data."""
71+
output_str = str(data.output).lower()
72+
73+
# Check for sensitive data patterns
74+
if "ssn" in output_str or "123-45-6789" in output_str:
75+
return ToolGuardrailFunctionOutput(
76+
tripwire_triggered=True,
77+
model_message="🚨 Tool output blocked: contains sensitive data",
78+
output_info={"blocked_pattern": "SSN", "tool": data.tool_call.name},
79+
)
80+
81+
return ToolGuardrailFunctionOutput(tripwire_triggered=False, output_info="Output validated")
82+
83+
84+
# Apply guardrails to tools
85+
send_email.tool_input_guardrails = [block_sensitive_words]
86+
get_user_data.tool_output_guardrails = [block_sensitive_output]
87+
88+
agent = Agent(
89+
name="Secure Assistant",
90+
instructions="You are a helpful assistant with access to email and user data tools.",
91+
tools=[send_email, get_user_data],
92+
)
93+
94+
95+
async def main():
96+
print("=== Tool Guardrails Example ===\n")
97+
98+
# Example 1: Normal operation - should work fine
99+
print("1. Normal email sending:")
100+
try:
101+
result = await Runner.run(agent, "Send a welcome email to [email protected]")
102+
print(f"✅ Success: {result.final_output}\n")
103+
except Exception as e:
104+
print(f"❌ Error: {e}\n")
105+
106+
# Example 2: Input guardrail triggers - should block suspicious content
107+
print("2. Attempting to send email with suspicious content:")
108+
try:
109+
result = await Runner.run(
110+
agent, "Send an email to [email protected] with the subject: orange"
111+
)
112+
print(f"✅ Success: {result.final_output}\n")
113+
except ToolInputGuardrailTripwireTriggered as e:
114+
print(f"🚨 Input guardrail triggered: {e.output.model_message}")
115+
print(f" Details: {e.output.output_info}\n")
116+
117+
# Example 3: Output guardrail triggers - should block sensitive data
118+
print("3. Attempting to get user data (contains SSN):")
119+
try:
120+
result = await Runner.run(agent, "Get the data for user ID user123")
121+
print(f"✅ Success: {result.final_output}\n")
122+
except ToolOutputGuardrailTripwireTriggered as e:
123+
print(f"🚨 Output guardrail triggered: {e.output.model_message}")
124+
print(f" Details: {e.output.output_info}\n")
125+
126+
127+
if __name__ == "__main__":
128+
asyncio.run(main())
129+
130+
"""
131+
Example output:
132+
133+
=== Tool Guardrails Example ===
134+
135+
1. Normal email sending:
136+
✅ Success: I've sent a welcome email to [email protected] with an appropriate subject and greeting message.
137+
138+
2. Attempting to send email with suspicious content:
139+
🚨 Input guardrail triggered: 🚨 Tool call blocked: contains 'orange'
140+
Details: {'blocked_word': 'orange', 'argument': 'subject'}
141+
142+
3. Attempting to get user data (contains SSN):
143+
🚨 Output guardrail triggered: 🚨 Tool output blocked: contains sensitive data
144+
Details: {'blocked_pattern': 'SSN', 'tool': 'get_user_data'}
145+
"""

examples/basic/tools.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def get_weather(city: Annotated[str, "The city to get the weather for"]) -> Weat
1818
print("[debug] get_weather called")
1919
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
2020

21+
2122
agent = Agent(
2223
name="Hello world",
2324
instructions="You are a helpful agent.",

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ plugins:
101101
- ref/usage.md
102102
- ref/exceptions.md
103103
- ref/guardrail.md
104+
- ref/tool_guardrails.md
104105
- ref/model_settings.md
105106
- ref/agent_output.md
106107
- ref/function_schema.md

src/agents/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
ModelBehaviorError,
2222
OutputGuardrailTripwireTriggered,
2323
RunErrorDetails,
24+
ToolInputGuardrailTripwireTriggered,
25+
ToolOutputGuardrailTripwireTriggered,
2426
UserError,
2527
)
2628
from .guardrail import (
@@ -83,6 +85,15 @@
8385
default_tool_error_function,
8486
function_tool,
8587
)
88+
from .tool_guardrails import (
89+
ToolGuardrailFunctionOutput,
90+
ToolInputGuardrail,
91+
ToolInputGuardrailData,
92+
ToolOutputGuardrail,
93+
ToolOutputGuardrailData,
94+
tool_input_guardrail,
95+
tool_output_guardrail,
96+
)
8697
from .tracing import (
8798
AgentSpanData,
8899
CustomSpanData,
@@ -191,6 +202,8 @@ def enable_verbose_stdout_logging():
191202
"AgentsException",
192203
"InputGuardrailTripwireTriggered",
193204
"OutputGuardrailTripwireTriggered",
205+
"ToolInputGuardrailTripwireTriggered",
206+
"ToolOutputGuardrailTripwireTriggered",
194207
"DynamicPromptFunction",
195208
"GenerateDynamicPromptData",
196209
"Prompt",
@@ -204,6 +217,13 @@ def enable_verbose_stdout_logging():
204217
"GuardrailFunctionOutput",
205218
"input_guardrail",
206219
"output_guardrail",
220+
"ToolInputGuardrail",
221+
"ToolOutputGuardrail",
222+
"ToolGuardrailFunctionOutput",
223+
"ToolInputGuardrailData",
224+
"ToolOutputGuardrailData",
225+
"tool_input_guardrail",
226+
"tool_output_guardrail",
207227
"handoff",
208228
"Handoff",
209229
"HandoffInputData",

0 commit comments

Comments
 (0)