Skip to content

Commit 198c8bd

Browse files
committed
Extract nested tool guardrail logic in _run_impl
1 parent 55c949a commit 198c8bd

File tree

3 files changed

+186
-76
lines changed

3 files changed

+186
-76
lines changed

examples/basic/tool_guardrails.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def get_user_data(user_id: str) -> dict[str, str]:
3232
"phone": "555-1234",
3333
}
3434

35+
3536
@function_tool
3637
def get_contact_info(user_id: str) -> dict[str, str]:
3738
"""Get contact info by ID."""
@@ -122,7 +123,9 @@ async def main():
122123

123124
# Example 2: Input guardrail triggers - function tool call is rejected but execution continues
124125
print("2. Attempting to send email with suspicious content:")
125-
result = await Runner.run(agent, "Send an email to [email protected] introducing the company ACME corp.")
126+
result = await Runner.run(
127+
agent, "Send an email to [email protected] introducing the company ACME corp."
128+
)
126129
print(f"❌ Guardrail rejected function tool call: {result.final_output}\n")
127130
except Exception as e:
128131
print(f"Error: {e}\n")
@@ -136,7 +139,6 @@ async def main():
136139
print("🚨 Output guardrail triggered: Execution halted for sensitive data")
137140
print(f"Details: {e.output.output_info}\n")
138141

139-
140142
try:
141143
# Example 4: Output guardrail triggers - reject returning function tool output but continue execution
142144
print("4. Rejecting function tool output containing phone numbers:")
@@ -145,6 +147,7 @@ async def main():
145147
except Exception as e:
146148
print(f"Error: {e}\n")
147149

150+
148151
if __name__ == "__main__":
149152
asyncio.run(main())
150153

@@ -157,12 +160,12 @@ async def main():
157160
✅ Successful tool execution: I've sent a welcome email to [email protected] with an appropriate subject and greeting message.
158161
159162
2. Attempting to send email with suspicious content:
160-
❌ Guardrail rejected function tool call: I'm unable to send the email mentioning ACME Corp as it was blocked by security guardrails.
163+
❌ Guardrail rejected function tool call: I'm unable to send the email as mentioning ACME Corp. is restricted.
161164
162165
3. Attempting to get user data (contains SSN). Execution blocked:
163166
🚨 Output guardrail triggered: Execution halted for sensitive data
164167
Details: {'blocked_pattern': 'SSN', 'tool': 'get_user_data'}
165168
166169
4. Rejecting function tool output containing sensitive data:
167-
✅ Successful tool execution: User data retrieved (phone number redacted for privacy)
170+
❌ Guardrail rejected function tool output: I'm unable to retrieve the contact info for user456 because it contains restricted information.
168171
"""

src/agents/_run_impl.py

Lines changed: 177 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,8 @@ async def execute_tools_and_side_effects(
341341
final_output=check_tool_use.final_output,
342342
hooks=hooks,
343343
context_wrapper=context_wrapper,
344+
tool_input_guardrail_results=tool_input_guardrail_results,
345+
tool_output_guardrail_results=tool_output_guardrail_results,
344346
)
345347

346348
# Now we can check if the model also produced a final output
@@ -574,6 +576,155 @@ def process_model_response(
574576
mcp_approval_requests=mcp_approval_requests,
575577
)
576578

579+
@classmethod
580+
async def _execute_input_guardrails(
581+
cls,
582+
*,
583+
func_tool: FunctionTool,
584+
tool_context: ToolContext[TContext],
585+
agent: Agent[TContext],
586+
tool_input_guardrail_results: list[ToolInputGuardrailResult],
587+
) -> str | None:
588+
"""Execute input guardrails for a tool.
589+
590+
Args:
591+
func_tool: The function tool being executed.
592+
tool_context: The tool execution context.
593+
agent: The agent executing the tool.
594+
tool_input_guardrail_results: List to append guardrail results to.
595+
596+
Returns:
597+
None if tool execution should proceed, or a message string if execution should be
598+
skipped.
599+
600+
Raises:
601+
ToolInputGuardrailTripwireTriggered: If a guardrail triggers an exception.
602+
"""
603+
if not func_tool.tool_input_guardrails:
604+
return None
605+
606+
for guardrail in func_tool.tool_input_guardrails:
607+
gr_out = await guardrail.run(
608+
ToolInputGuardrailData(
609+
context=tool_context,
610+
agent=agent,
611+
)
612+
)
613+
614+
# Store the guardrail result
615+
tool_input_guardrail_results.append(
616+
ToolInputGuardrailResult(
617+
guardrail=guardrail,
618+
output=gr_out,
619+
)
620+
)
621+
622+
# Handle different behavior types
623+
if gr_out.behavior["type"] == "raise_exception":
624+
raise ToolInputGuardrailTripwireTriggered(guardrail=guardrail, output=gr_out)
625+
elif gr_out.behavior["type"] == "reject_content":
626+
# Set final_result to the message and skip tool execution
627+
return gr_out.behavior["message"]
628+
elif gr_out.behavior["type"] == "allow":
629+
# Continue to next guardrail or tool execution
630+
continue
631+
632+
return None
633+
634+
@classmethod
635+
async def _execute_output_guardrails(
636+
cls,
637+
*,
638+
func_tool: FunctionTool,
639+
tool_context: ToolContext[TContext],
640+
agent: Agent[TContext],
641+
real_result: Any,
642+
tool_output_guardrail_results: list[ToolOutputGuardrailResult],
643+
) -> Any:
644+
"""Execute output guardrails for a tool.
645+
646+
Args:
647+
func_tool: The function tool being executed.
648+
tool_context: The tool execution context.
649+
agent: The agent executing the tool.
650+
real_result: The actual result from the tool execution.
651+
tool_output_guardrail_results: List to append guardrail results to.
652+
653+
Returns:
654+
The final result after guardrail processing (may be modified).
655+
656+
Raises:
657+
ToolOutputGuardrailTripwireTriggered: If a guardrail triggers an exception.
658+
"""
659+
if not func_tool.tool_output_guardrails:
660+
return real_result
661+
662+
final_result = real_result
663+
for output_guardrail in func_tool.tool_output_guardrails:
664+
gr_out = await output_guardrail.run(
665+
ToolOutputGuardrailData(
666+
context=tool_context,
667+
agent=agent,
668+
output=real_result,
669+
)
670+
)
671+
672+
# Store the guardrail result
673+
tool_output_guardrail_results.append(
674+
ToolOutputGuardrailResult(
675+
guardrail=output_guardrail,
676+
output=gr_out,
677+
)
678+
)
679+
680+
# Handle different behavior types
681+
if gr_out.behavior["type"] == "raise_exception":
682+
raise ToolOutputGuardrailTripwireTriggered(
683+
guardrail=output_guardrail, output=gr_out
684+
)
685+
elif gr_out.behavior["type"] == "reject_content":
686+
# Override the result with the guardrail message
687+
final_result = gr_out.behavior["message"]
688+
break
689+
elif gr_out.behavior["type"] == "allow":
690+
# Continue to next guardrail
691+
continue
692+
693+
return final_result
694+
695+
@classmethod
696+
async def _execute_tool_with_hooks(
697+
cls,
698+
*,
699+
func_tool: FunctionTool,
700+
tool_context: ToolContext[TContext],
701+
agent: Agent[TContext],
702+
hooks: RunHooks[TContext],
703+
tool_call: ResponseFunctionToolCall,
704+
) -> Any:
705+
"""Execute the core tool function with before/after hooks.
706+
707+
Args:
708+
func_tool: The function tool being executed.
709+
tool_context: The tool execution context.
710+
agent: The agent executing the tool.
711+
hooks: The run hooks to execute.
712+
tool_call: The tool call details.
713+
714+
Returns:
715+
The result from the tool execution.
716+
"""
717+
await asyncio.gather(
718+
hooks.on_tool_start(tool_context, agent, func_tool),
719+
(
720+
agent.hooks.on_tool_start(tool_context, agent, func_tool)
721+
if agent.hooks
722+
else _coro.noop_coroutine()
723+
),
724+
)
725+
726+
return await func_tool.on_invoke_tool(tool_context, tool_call.arguments)
727+
577728
@classmethod
578729
async def execute_function_tool_calls(
579730
cls,
@@ -603,83 +754,35 @@ async def run_single_tool(
603754
span_fn.span_data.input = tool_call.arguments
604755
try:
605756
# 1) Run input tool guardrails, if any
606-
final_result: Any | None = None
607-
if func_tool.tool_input_guardrails:
608-
for guardrail in func_tool.tool_input_guardrails:
609-
gr_out = await guardrail.run(
610-
ToolInputGuardrailData(
611-
context=tool_context,
612-
agent=agent,
613-
)
614-
)
615-
616-
# Store the guardrail result
617-
tool_input_guardrail_results.append(
618-
ToolInputGuardrailResult(
619-
guardrail=guardrail,
620-
output=gr_out,
621-
)
622-
)
623-
624-
# Handle different behavior types
625-
if gr_out.behavior["type"] == "raise_exception":
626-
raise ToolInputGuardrailTripwireTriggered(
627-
guardrail=guardrail, output=gr_out
628-
)
629-
elif gr_out.behavior["type"] == "reject_content":
630-
# Set final_result to the message and skip tool execution
631-
final_result = gr_out.behavior["message"]
632-
break
633-
elif gr_out.behavior["type"] == "allow":
634-
# Continue to next guardrail or tool execution
635-
continue
636-
637-
if final_result is None:
757+
rejected_message = await cls._execute_input_guardrails(
758+
func_tool=func_tool,
759+
tool_context=tool_context,
760+
agent=agent,
761+
tool_input_guardrail_results=tool_input_guardrail_results,
762+
)
763+
764+
if rejected_message is not None:
765+
# Input guardrail rejected the tool call
766+
final_result = rejected_message
767+
else:
638768
# 2) Actually run the tool
639-
await asyncio.gather(
640-
hooks.on_tool_start(tool_context, agent, func_tool),
641-
(
642-
agent.hooks.on_tool_start(tool_context, agent, func_tool)
643-
if agent.hooks
644-
else _coro.noop_coroutine()
645-
),
646-
)
647-
real_result = await func_tool.on_invoke_tool(
648-
tool_context, tool_call.arguments
769+
real_result = await cls._execute_tool_with_hooks(
770+
func_tool=func_tool,
771+
tool_context=tool_context,
772+
agent=agent,
773+
hooks=hooks,
774+
tool_call=tool_call,
649775
)
650776

651777
# 3) Run output tool guardrails, if any
652-
final_result = real_result
653-
if func_tool.tool_output_guardrails:
654-
for output_guardrail in func_tool.tool_output_guardrails:
655-
gr_out = await output_guardrail.run(
656-
ToolOutputGuardrailData(
657-
context=tool_context,
658-
agent=agent,
659-
output=real_result,
660-
)
661-
)
778+
final_result = await cls._execute_output_guardrails(
779+
func_tool=func_tool,
780+
tool_context=tool_context,
781+
agent=agent,
782+
real_result=real_result,
783+
tool_output_guardrail_results=tool_output_guardrail_results,
784+
)
662785

663-
# Store the guardrail result
664-
tool_output_guardrail_results.append(
665-
ToolOutputGuardrailResult(
666-
guardrail=output_guardrail,
667-
output=gr_out,
668-
)
669-
)
670-
671-
# Handle different behavior types
672-
if gr_out.behavior["type"] == "raise_exception":
673-
raise ToolOutputGuardrailTripwireTriggered(
674-
guardrail=output_guardrail, output=gr_out
675-
)
676-
elif gr_out.behavior["type"] == "reject_content":
677-
# Override the result with the guardrail message
678-
final_result = gr_out.behavior["message"]
679-
break
680-
elif gr_out.behavior["type"] == "allow":
681-
# Continue to next guardrail
682-
continue
683786
# 4) Tool end hooks (with final result, which may have been overridden)
684787
await asyncio.gather(
685788
hooks.on_tool_end(tool_context, agent, func_tool, final_result),
@@ -932,6 +1035,8 @@ async def execute_handoffs(
9321035
pre_step_items=pre_step_items,
9331036
new_step_items=new_step_items,
9341037
next_step=NextStepHandoff(new_agent),
1038+
tool_input_guardrail_results=[],
1039+
tool_output_guardrail_results=[],
9351040
)
9361041

9371042
@classmethod

src/agents/run.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,8 @@ def run_streamed(
715715
max_turns=max_turns,
716716
input_guardrail_results=[],
717717
output_guardrail_results=[],
718+
tool_input_guardrail_results=[],
719+
tool_output_guardrail_results=[],
718720
_current_agent_output_schema=output_schema,
719721
trace=new_trace,
720722
context_wrapper=context_wrapper,

0 commit comments

Comments
 (0)