
Commit 2615892

Return Guardrail token usage (#62)

* Returning token usage by Guardrails
* Fix AttributeError on total_token_usage

1 parent 0ec56d3 · commit 2615892

27 files changed: +1203 -151 lines

docs/agents_sdk_integration.md

Lines changed: 46 additions & 0 deletions
@@ -81,6 +81,52 @@ from guardrails import JsonString
 agent = GuardrailAgent(config=JsonString('{"version": 1, ...}'), ...)
 ```
 
+## Token Usage Tracking
+
+Track token usage from LLM-based guardrails using the unified `total_guardrail_token_usage` function:
+
+```python
+from guardrails import GuardrailAgent, total_guardrail_token_usage
+from agents import Runner
+
+agent = GuardrailAgent(config="config.json", name="Assistant", instructions="...")
+result = await Runner.run(agent, "Hello")
+
+# Get aggregated token usage from all guardrails
+tokens = total_guardrail_token_usage(result)
+print(f"Guardrail tokens used: {tokens['total_tokens']}")
+```
+
+### Per-Stage Token Usage
+
+For per-stage token usage, access the guardrail results directly on the `RunResult`:
+
+```python
+# Input guardrails (agent-level)
+for gr in result.input_guardrail_results:
+    usage = gr.output.output_info.get("token_usage") if gr.output.output_info else None
+    if usage:
+        print(f"Input guardrail: {usage['total_tokens']} tokens")
+
+# Output guardrails (agent-level)
+for gr in result.output_guardrail_results:
+    usage = gr.output.output_info.get("token_usage") if gr.output.output_info else None
+    if usage:
+        print(f"Output guardrail: {usage['total_tokens']} tokens")
+
+# Tool input guardrails (per-tool)
+for gr in result.tool_input_guardrail_results:
+    usage = gr.output.output_info.get("token_usage") if gr.output.output_info else None
+    if usage:
+        print(f"Tool input guardrail: {usage['total_tokens']} tokens")
+
+# Tool output guardrails (per-tool)
+for gr in result.tool_output_guardrail_results:
+    usage = gr.output.output_info.get("token_usage") if gr.output.output_info else None
+    if usage:
+        print(f"Tool output guardrail: {usage['total_tokens']} tokens")
+```
+
 ## Next Steps
 
 - Use the [Guardrails Wizard](https://guardrails.openai.com/) to generate your configuration
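
Editor's note: the four per-stage loops in the new docs section are identical except for the `RunResult` attribute they walk. A small convenience helper (hypothetical, not part of this commit; it relies only on the four attributes and the `output.output_info["token_usage"]` shape shown above) could collapse them:

```python
# Hypothetical helper -- not in this commit. It uses only the four RunResult
# attributes and the output.output_info["token_usage"] shape demonstrated
# in the docs section above.
def per_stage_guardrail_tokens(result) -> dict[str, int]:
    """Sum guardrail total_tokens per stage of an Agents SDK RunResult."""
    stages = {
        "input": result.input_guardrail_results,
        "output": result.output_guardrail_results,
        "tool_input": result.tool_input_guardrail_results,
        "tool_output": result.tool_output_guardrail_results,
    }
    totals: dict[str, int] = {}
    for stage, results in stages.items():
        for gr in results:
            usage = gr.output.output_info.get("token_usage") if gr.output.output_info else None
            if usage:
                totals[stage] = totals.get(stage, 0) + usage["total_tokens"]
    return totals
```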

docs/quickstart.md

Lines changed: 81 additions & 0 deletions
@@ -203,6 +203,87 @@ client = GuardrailsAsyncOpenAI(
 )
 ```
 
+## Token Usage Tracking
+
+LLM-based guardrails (Jailbreak, Custom Prompt Check, etc.) consume tokens. You can track token usage across all guardrail calls using the unified `total_guardrail_token_usage` function:
+
+```python
+from guardrails import GuardrailsAsyncOpenAI, total_guardrail_token_usage
+
+client = GuardrailsAsyncOpenAI(config="config.json")
+response = await client.responses.create(model="gpt-4o", input="Hello")
+
+# Get aggregated token usage from all guardrails
+tokens = total_guardrail_token_usage(response)
+print(f"Guardrail tokens used: {tokens['total_tokens']}")
+# Output: Guardrail tokens used: 425
+```
+
+The function returns a dictionary:
+```python
+{
+    "prompt_tokens": 300,      # Sum of prompt tokens across all LLM guardrails
+    "completion_tokens": 125,  # Sum of completion tokens
+    "total_tokens": 425,       # Total tokens used by guardrails
+}
+```
+
+### Works Across All Surfaces
+
+`total_guardrail_token_usage` works with any guardrails result type:
+
+```python
+# OpenAI client responses
+response = await client.responses.create(...)
+tokens = total_guardrail_token_usage(response)
+
+# Streaming (use the last chunk)
+async for chunk in stream:
+    last_chunk = chunk
+tokens = total_guardrail_token_usage(last_chunk)
+
+# Agents SDK
+result = await Runner.run(agent, input)
+tokens = total_guardrail_token_usage(result)
+```
+
+### Per-Guardrail Token Usage
+
+Each guardrail result includes its own token usage in the `info` dict:
+
+**OpenAI Clients (GuardrailsAsyncOpenAI, etc.)**:
+
+```python
+response = await client.responses.create(model="gpt-4.1", input="Hello")
+
+for gr in response.guardrail_results.all_results:
+    usage = gr.info.get("token_usage")
+    if usage:
+        print(f"{gr.info['guardrail_name']}: {usage['total_tokens']} tokens")
+```
+
+**Agents SDK** - access token usage per stage via `RunResult`:
+
+```python
+result = await Runner.run(agent, "Hello")
+
+# Input guardrails
+for gr in result.input_guardrail_results:
+    usage = gr.output.output_info.get("token_usage") if gr.output.output_info else None
+    if usage:
+        print(f"Input: {usage['total_tokens']} tokens")
+
+# Output guardrails
+for gr in result.output_guardrail_results:
+    usage = gr.output.output_info.get("token_usage") if gr.output.output_info else None
+    if usage:
+        print(f"Output: {usage['total_tokens']} tokens")
+
+# Tool guardrails: result.tool_input_guardrail_results, result.tool_output_guardrail_results
+```
+
+Non-LLM guardrails (URL Filter, Moderation, PII) don't consume tokens and won't have `token_usage` in their info.
+
 ## Next Steps
 
 - Explore [examples](./examples.md) for advanced patterns
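
Editor's note: because non-LLM guardrails never contribute `token_usage`, a pipeline containing only URL Filter, Moderation, or PII checks yields no usage data at all. Judging by the `total_token_usage` docstring added in `_base_client.py` below, the aggregate fields then stay `None` rather than `0`, so callers should not assume an integer. A hedged illustration:

```python
# Hedged illustration, assuming the None-when-no-data behavior described in the
# total_token_usage docstring also applies to total_guardrail_token_usage.
from guardrails import GuardrailsAsyncOpenAI, total_guardrail_token_usage


async def report_guardrail_cost(client: GuardrailsAsyncOpenAI) -> None:
    response = await client.responses.create(model="gpt-4o", input="Hello")
    tokens = total_guardrail_token_usage(response)
    if tokens["total_tokens"] is None:
        print("No LLM-based guardrails ran; no guardrail tokens used.")
    else:
        print(f"Guardrail tokens: {tokens['total_tokens']}")
```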

examples/basic/hello_world.py

Lines changed: 13 additions & 3 deletions
@@ -6,17 +6,24 @@
 from rich.console import Console
 from rich.panel import Panel
 
-from guardrails import GuardrailsAsyncOpenAI, GuardrailTripwireTriggered
+from guardrails import GuardrailsAsyncOpenAI, GuardrailTripwireTriggered, total_guardrail_token_usage
 
 console = Console()
 
-# Pipeline configuration with pre_flight and input guardrails
+# Define your pipeline configuration
 PIPELINE_CONFIG = {
     "version": 1,
     "pre_flight": {
         "version": 1,
         "guardrails": [
-            {"name": "Contains PII", "config": {"entities": ["US_SSN", "PHONE_NUMBER", "EMAIL_ADDRESS"]}},
+            {"name": "Moderation", "config": {"categories": ["hate", "violence"]}},
+            {
+                "name": "Jailbreak",
+                "config": {
+                    "model": "gpt-4.1-mini",
+                    "confidence_threshold": 0.7,
+                },
+            },
         ],
     },
     "input": {
@@ -52,6 +59,9 @@ async def process_input(
     # Show guardrail results if any were run
     if response.guardrail_results.all_results:
        console.print(f"[dim]Guardrails checked: {len(response.guardrail_results.all_results)}[/dim]")
+        # Use unified function - works with any guardrails response type
+        tokens = total_guardrail_token_usage(response)
+        console.print(f"[dim]Token usage: {tokens}[/dim]")
 
     return response.id

examples/basic/local_model.py

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,7 @@
 from rich.console import Console
 from rich.panel import Panel
 
-from guardrails import GuardrailsAsyncOpenAI, GuardrailTripwireTriggered
+from guardrails import GuardrailsAsyncOpenAI, GuardrailTripwireTriggered, total_guardrail_token_usage
 
 console = Console()
 
@@ -50,6 +50,7 @@ async def process_input(
     # Access response content using standard OpenAI API
     response_content = response.choices[0].message.content
     console.print(f"\nAssistant output: {response_content}", end="\n\n")
+    console.print(f"Token usage: {total_guardrail_token_usage(response)}")
 
     # Add to conversation history
     input_data.append({"role": "user", "content": user_input})

examples/basic/multi_bundle.py

Lines changed: 18 additions & 4 deletions
@@ -7,7 +7,7 @@
 from rich.live import Live
 from rich.panel import Panel
 
-from guardrails import GuardrailsAsyncOpenAI, GuardrailTripwireTriggered
+from guardrails import GuardrailsAsyncOpenAI, GuardrailTripwireTriggered, total_guardrail_token_usage
 
 console = Console()
 
@@ -22,6 +22,13 @@
                 "name": "URL Filter",
                 "config": {"url_allow_list": ["example.com", "baz.com"]},
             },
+            {
+                "name": "Jailbreak",
+                "config": {
+                    "model": "gpt-4.1-mini",
+                    "confidence_threshold": 0.7,
+                },
+            },
         ],
     },
     "input": {
@@ -63,19 +70,26 @@ async def process_input(
 
     # Stream the assistant's output inside a Rich Live panel
     output_text = "Assistant output: "
+    last_chunk = None
     with Live(output_text, console=console, refresh_per_second=10) as live:
         try:
             async for chunk in stream:
+                last_chunk = chunk
                 # Access streaming response exactly like native OpenAI API (flattened)
                 if hasattr(chunk, "delta") and chunk.delta:
                     output_text += chunk.delta
                     live.update(output_text)
 
             # Get the response ID from the final chunk
             response_id_to_return = None
-            if hasattr(chunk, "response") and hasattr(chunk.response, "id"):
-                response_id_to_return = chunk.response.id
-
+            if last_chunk and hasattr(last_chunk, "response") and hasattr(last_chunk.response, "id"):
+                response_id_to_return = last_chunk.response.id
+
+            # Print token usage from guardrail results (unified interface)
+            if last_chunk:
+                tokens = total_guardrail_token_usage(last_chunk)
+                if tokens["total_tokens"]:
+                    console.print(f"[dim]📊 Guardrail tokens: {tokens['total_tokens']}[/dim]")
             return response_id_to_return
 
         except GuardrailTripwireTriggered:

src/guardrails/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -40,7 +40,7 @@
     run_guardrails,
 )
 from .spec import GuardrailSpecMetadata
-from .types import GuardrailResult
+from .types import GuardrailResult, total_guardrail_token_usage
 
 __all__ = [
     "ConfiguredGuardrail",  # configured, executable object
@@ -64,6 +64,7 @@
     "load_pipeline_bundles",
     "default_spec_registry",
     "resources",  # resource modules
+    "total_guardrail_token_usage",  # unified token usage aggregation
 ]
 
 __version__: str = _m.version("openai-guardrails")

src/guardrails/_base_client.py

Lines changed: 20 additions & 10 deletions
@@ -19,7 +19,7 @@
 
 from .context import has_context
 from .runtime import load_pipeline_bundles
-from .types import GuardrailLLMContextProto, GuardrailResult
+from .types import GuardrailLLMContextProto, GuardrailResult, aggregate_token_usage_from_infos
 from .utils.context import validate_guardrail_context
 from .utils.conversation import append_assistant_response, normalize_conversation
 
@@ -77,6 +77,23 @@ def triggered_results(self) -> list[GuardrailResult]:
         """Get only the guardrail results that triggered tripwires."""
         return [r for r in self.all_results if r.tripwire_triggered]
 
+    @property
+    def total_token_usage(self) -> dict[str, Any]:
+        """Aggregate token usage across all LLM-based guardrails.
+
+        Sums prompt_tokens, completion_tokens, and total_tokens from all
+        guardrail results that include token_usage in their info dict.
+        Non-LLM guardrails (which don't have token_usage) are skipped.
+
+        Returns:
+            Dictionary with:
+            - prompt_tokens: Sum of all prompt tokens (or None if no data)
+            - completion_tokens: Sum of all completion tokens (or None if no data)
+            - total_tokens: Sum of all total tokens (or None if no data)
+        """
+        infos = (result.info for result in self.all_results)
+        return aggregate_token_usage_from_infos(infos)
+
 
 @dataclass(frozen=True, slots=True, weakref_slot=True)
 class GuardrailsResponse:
@@ -427,8 +444,7 @@ def _mask_text(text: str) -> str:
                 or (
                     len(candidate_lower) >= 3
                     and any(  # Any 3-char chunk overlaps
-                        candidate_lower[i : i + 3] in detected_lower
-                        for i in range(len(candidate_lower) - 2)
+                        candidate_lower[i : i + 3] in detected_lower for i in range(len(candidate_lower) - 2)
                     )
                 )
             )
@@ -459,13 +475,7 @@ def _mask_text(text: str) -> str:
                     modified_content.append(part)
                 else:
                     # Handle object-based content parts
-                    if (
-                        hasattr(part, "type")
-                        and hasattr(part, "text")
-                        and part.type in _TEXT_CONTENT_TYPES
-                        and isinstance(part.text, str)
-                        and part.text
-                    ):
+                    if hasattr(part, "type") and hasattr(part, "text") and part.type in _TEXT_CONTENT_TYPES and isinstance(part.text, str) and part.text:
                         try:
                             part.text = _mask_text(part.text)
                         except Exception:
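
Editor's note: the `total_token_usage` property above delegates to `aggregate_token_usage_from_infos`, which this commit imports from `.types` without showing its body. Below is a minimal sketch of such an aggregator, with the field names and None-when-no-data behavior taken from the docstring; everything else is an assumption.

```python
# Hypothetical sketch -- the real implementation lives in guardrails/types.py and
# is not shown in this diff. Semantics follow the total_token_usage docstring:
# sum each field across results that carry token_usage; stay None if none do.
from collections.abc import Iterable
from typing import Any


def aggregate_token_usage_from_infos(infos: Iterable[dict[str, Any] | None]) -> dict[str, int | None]:
    totals: dict[str, int | None] = {
        "prompt_tokens": None,
        "completion_tokens": None,
        "total_tokens": None,
    }
    for info in infos:
        usage = info.get("token_usage") if info else None
        if not usage:
            continue  # non-LLM guardrails carry no token_usage
        for key in totals:
            value = usage.get(key)
            if value is not None:
                totals[key] = (totals[key] or 0) + value
    return totals
```

A plausible reading is that the documentation-level `total_guardrail_token_usage(...)` reaches this same aggregation on whatever result object it receives, which would explain why it behaves uniformly across client responses, stream chunks, and Agents SDK run results.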
