From be849c05db40f3b1a01d635cbc20ef94054b7071 Mon Sep 17 00:00:00 2001
From: Steven C
Date: Tue, 2 Dec 2025 12:07:49 -0500
Subject: [PATCH 1/3] Return Guardrail token usage

---
 docs/agents_sdk_integration.md                |  25 ++
 docs/quickstart.md                            |  69 +++++
 examples/basic/hello_world.ts                 |  18 +-
 examples/basic/local_model.ts                 |   6 +-
 ...ltiturn_with_prompt_injection_detection.ts |  24 +-
 examples/basic/streaming.ts                   |  23 +-
 src/__tests__/unit/agents.test.ts             |  66 ++++-
 src/__tests__/unit/base-client.test.ts        |  61 +++++
 src/__tests__/unit/checks/jailbreak.test.ts   |  70 +++--
 src/__tests__/unit/llm-base.test.ts           |  36 ++-
 .../unit/prompt_injection_detection.test.ts   |   1 +
 src/__tests__/unit/types.test.ts              |  86 +++++-
 src/base-client.ts                            |   7 +
 src/checks/hallucination-detection.ts         |  57 ++--
 src/checks/jailbreak.ts                       |  18 +-
 src/checks/llm-base.ts                        | 124 +++++----
 src/checks/prompt_injection_detection.ts      |  61 ++++-
 src/index.ts                                  |   4 +-
 src/types.ts                                  | 246 ++++++++++++++++++
 19 files changed, 887 insertions(+), 115 deletions(-)

diff --git a/docs/agents_sdk_integration.md b/docs/agents_sdk_integration.md
index 442aae5..6a4f055 100644
--- a/docs/agents_sdk_integration.md
+++ b/docs/agents_sdk_integration.md
@@ -111,3 +111,28 @@ const agent = await GuardrailAgent.create(
 - Explore available guardrails for your use case
 - Learn about pipeline configuration in our [quickstart](./quickstart.md)
 - For more details on the OpenAI Agents SDK, refer to the [Agent SDK documentation](https://openai.github.io/openai-agents-js/).
+
+## Token Usage Tracking
+
+!!! warning "JavaScript Agents SDK Limitation"
+    The JavaScript Agents SDK (`@openai/agents`) does not currently return guardrail results in the `RunResult` object. This means `totalGuardrailTokenUsage()` cannot retrieve token counts from Agents SDK runs.
+
+    **For token usage tracking, use `GuardrailsOpenAI` instead of `GuardrailAgent`.** The Python Agents SDK does support this feature.
+
+When a guardrail **triggers** (throws `InputGuardrailTripwireTriggered` or `OutputGuardrailTripwireTriggered`), token usage **is** available in the error's result object:
+
+```typescript
+try {
+  const result = await Runner.run(agent, userInput);
+} catch (error) {
+  if (error.constructor.name === 'InputGuardrailTripwireTriggered') {
+    // Token usage available when guardrail triggers
+    const usage = error.result?.output?.outputInfo?.token_usage;
+    if (usage) {
+      console.log(`Guardrail tokens: ${usage.total_tokens}`);
+    }
+  }
+}
+```
+
+For full token usage tracking across all guardrail runs (passing and failing), use the `GuardrailsOpenAI` client instead; see the [quickstart](./quickstart.md#token-usage-tracking) for details.
diff --git a/docs/quickstart.md b/docs/quickstart.md
index 78505ad..da04b90 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -205,6 +205,75 @@ const client = await GuardrailsOpenAI.create(
 );
 ```
 
+## Token Usage Tracking
+
+LLM-based guardrails (Jailbreak, custom prompt checks, etc.) consume tokens.
+Keep track of those costs with the `totalGuardrailTokenUsage` helper:
+
+```typescript
+import { GuardrailsOpenAI, totalGuardrailTokenUsage } from '@openai/guardrails';
+
+const client = await GuardrailsOpenAI.create(CONFIG);
+const response = await client.guardrails.responses.create({
+  model: 'gpt-4.1-mini',
+  input: 'Hello!',
+});
+
+const tokens = totalGuardrailTokenUsage(response);
+console.log(`Guardrail tokens used: ${tokens.total_tokens}`);
+// => Guardrail tokens used: 425
+```
+
+The helper returns:
+
+```typescript
+{
+  prompt_tokens: 300,     // Sum of prompt tokens across all LLM guardrails
+  completion_tokens: 125, // Sum of completion tokens
+  total_tokens: 425,      // Total guardrail tokens
+}
+```
+
+### Works With GuardrailsOpenAI Clients
+
+`totalGuardrailTokenUsage` works across all client types and endpoints:
+
+- **OpenAI** and **Azure OpenAI** clients
+- **Third-party providers** - any OpenAI-compatible API wrapper
+- **Endpoints** - both `responses` and `chat.completions`
+- **Streaming** - capture usage from the final chunk
+
+```typescript
+// OpenAI client responses
+const response = await client.guardrails.responses.create(...);
+const tokens = totalGuardrailTokenUsage(response);
+
+// Streaming - use the final chunk
+let lastChunk: unknown;
+for await (const chunk of stream) {
+  lastChunk = chunk;
+}
+const streamingTokens = lastChunk ? totalGuardrailTokenUsage(lastChunk) : null;
+```
+
+**Note:** The JavaScript Agents SDK (`@openai/agents`) does not currently populate guardrail results in the `RunResult` object, so `totalGuardrailTokenUsage()` will return empty results for Agents SDK runs.
+
+### Per-Guardrail Usage
+
+Each guardrail result includes its own `token_usage` entry:
+
+```typescript
+const response = await client.guardrails.responses.create(...);
+for (const gr of response.guardrail_results.allResults) {
+  const usage = gr.info.token_usage;
+  if (usage) {
+    console.log(`${gr.info.guardrail_name}: ${usage.total_tokens} tokens`);
+  }
+}
+```
+
+Non-LLM guardrails (PII, Moderation, URL Filter, etc.) do not consume tokens, so `token_usage` is omitted from their results.
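+
+When usage could not be recorded (for example, the LLM call failed or the provider omitted usage data), the counts are `null` and the entry carries an `unavailable_reason` explaining why. A minimal sketch of handling both cases, reusing the loop above:
+
+```typescript
+for (const gr of response.guardrail_results.allResults) {
+  const usage = gr.info.token_usage;
+  if (!usage) continue; // non-LLM guardrail: no token_usage entry
+  if (usage.total_tokens === null) {
+    // Counts are null when usage was unavailable; unavailable_reason says why
+    console.log(`${gr.info.guardrail_name}: usage unavailable (${usage.unavailable_reason ?? 'unknown'})`);
+  } else {
+    console.log(`${gr.info.guardrail_name}: ${usage.total_tokens} tokens`);
+  }
+}
+```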
+ ## Next Steps - Explore TypeScript [examples](https://github.com/openai/openai-guardrails-js/tree/main/examples) for advanced patterns diff --git a/examples/basic/hello_world.ts b/examples/basic/hello_world.ts index 10af42d..e2f3a22 100644 --- a/examples/basic/hello_world.ts +++ b/examples/basic/hello_world.ts @@ -8,19 +8,23 @@ */ import * as readline from 'readline'; -import { GuardrailsOpenAI, GuardrailTripwireTriggered } from '../../src'; +import { GuardrailsOpenAI, GuardrailTripwireTriggered, totalGuardrailTokenUsage } from '../../src'; -// Pipeline configuration with preflight PII masking and input guardrails +// Pipeline configuration with preflight and input guardrails const PIPELINE_CONFIG = { version: 1, pre_flight: { version: 1, guardrails: [ { - name: 'Contains PII', + name: 'Moderation', + config: { categories: ['hate', 'violence'] }, + }, + { + name: 'Jailbreak', config: { - entities: ['US_SSN', 'PHONE_NUMBER', 'EMAIL_ADDRESS'], - block: true, // Use masking mode (default) - masks PII without blocking + model: 'gpt-4.1-mini', + confidence_threshold: 0.7, }, }, ], @@ -76,6 +80,10 @@ async function processInput( // Show guardrail results if any were run if (response.guardrail_results.allResults.length > 0) { console.log(`[dim]Guardrails checked: ${response.guardrail_results.allResults.length}[/dim]`); + const usage = totalGuardrailTokenUsage(response); + if (usage.total_tokens !== null) { + console.log(`[dim]Token usage: ${JSON.stringify(usage)}[/dim]`); + } } return response.id; diff --git a/examples/basic/local_model.ts b/examples/basic/local_model.ts index 08d5c7c..42309e9 100644 --- a/examples/basic/local_model.ts +++ b/examples/basic/local_model.ts @@ -2,7 +2,7 @@ * Example: Guardrail bundle using Ollama's Gemma3 model with GuardrailsClient. */ -import { GuardrailsOpenAI, GuardrailTripwireTriggered } from '../../src'; +import { GuardrailsOpenAI, GuardrailTripwireTriggered, totalGuardrailTokenUsage } from '../../src'; import * as readline from 'readline'; import { OpenAI } from 'openai'; @@ -46,6 +46,10 @@ async function processInput( // Access response content using standard OpenAI API const responseContent = response.choices[0].message.content ?? 
''; console.log(`\nAssistant output: ${responseContent}\n`); + const usage = totalGuardrailTokenUsage(response); + if (usage.total_tokens !== null) { + console.log(`Token usage: ${usage.total_tokens}`); + } // Guardrails passed - now safe to add to conversation history conversation.push({ role: 'user', content: userInput }); diff --git a/examples/basic/multiturn_with_prompt_injection_detection.ts b/examples/basic/multiturn_with_prompt_injection_detection.ts index 62e1f03..0a30df3 100644 --- a/examples/basic/multiturn_with_prompt_injection_detection.ts +++ b/examples/basic/multiturn_with_prompt_injection_detection.ts @@ -26,7 +26,12 @@ */ import * as readline from 'readline'; -import { GuardrailsOpenAI, GuardrailTripwireTriggered, GuardrailsResponse } from '../../src'; +import { + GuardrailsOpenAI, + GuardrailTripwireTriggered, + GuardrailsResponse, + totalGuardrailTokenUsage, +} from '../../src'; // Tool implementations (mocked) function get_horoscope(sign: string): { horoscope: string } { @@ -299,6 +304,15 @@ async function main(malicious: boolean = false): Promise { printGuardrailResults('initial', response); + const initialUsage = totalGuardrailTokenUsage(response); + if (initialUsage.total_tokens !== null) { + console.log( + `[dim]Guardrail tokens (initial): ${initialUsage.total_tokens} Ā· prompt=${ + initialUsage.prompt_tokens ?? 0 + }, completion=${initialUsage.completion_tokens ?? 0}[/dim]` + ); + } + assistantOutputs = response.output ?? []; // Guardrails passed - now safe to add user message to conversation history @@ -394,6 +408,14 @@ async function main(malicious: boolean = false): Promise { }); printGuardrailResults('final', response); + const finalUsage = totalGuardrailTokenUsage(response); + if (finalUsage.total_tokens !== null) { + console.log( + `[dim]Guardrail tokens (final): ${finalUsage.total_tokens} Ā· prompt=${ + finalUsage.prompt_tokens ?? 0 + }, completion=${finalUsage.completion_tokens ?? 0}[/dim]` + ); + } console.log(`\nšŸ¤– Assistant: ${response.output_text}`); // Guardrails passed - now safe to add tool results and assistant responses to history diff --git a/examples/basic/streaming.ts b/examples/basic/streaming.ts index 13965f5..669eb7f 100644 --- a/examples/basic/streaming.ts +++ b/examples/basic/streaming.ts @@ -3,7 +3,7 @@ * Streams output using console logging. 
*/ -import { GuardrailsOpenAI, GuardrailTripwireTriggered } from '../../src'; +import { GuardrailsOpenAI, GuardrailTripwireTriggered, totalGuardrailTokenUsage } from '../../src'; import * as readline from 'readline'; // Define your pipeline configuration @@ -14,10 +14,14 @@ const PIPELINE_CONFIG = { version: 1, guardrails: [ { - name: 'Contains PII', + name: 'Moderation', + config: { categories: ['hate', 'violence'] }, + }, + { + name: 'Jailbreak', config: { - entities: ['US_SSN', 'PHONE_NUMBER', 'EMAIL_ADDRESS'], - block: false, // Use masking mode (default) - masks PII without blocking + model: 'gpt-4.1-mini', + confidence_threshold: 0.7, }, }, ], @@ -49,6 +53,7 @@ const PIPELINE_CONFIG = { config: { entities: ['US_SSN', 'PHONE_NUMBER', 'EMAIL_ADDRESS'], block: true, // Use blocking mode on output + detect_encoded_pii: false, }, }, ], @@ -78,8 +83,10 @@ async function processInput( console.log(outputText); let responseIdToReturn: string | null = null; + let lastChunk: unknown = null; for await (const chunk of stream) { + lastChunk = chunk; // Access streaming response exactly like native OpenAI API if ('delta' in chunk && chunk.delta && typeof chunk.delta === 'string') { outputText += chunk.delta; @@ -99,6 +106,14 @@ async function processInput( } console.log(); // New line after streaming + + if (lastChunk) { + const usage = totalGuardrailTokenUsage(lastChunk); + if (usage.total_tokens !== null) { + console.log(`[dim]šŸ“Š Guardrail tokens: ${usage.total_tokens}[/dim]`); + } + } + return responseIdToReturn; } catch (error) { if (error instanceof GuardrailTripwireTriggered) { diff --git a/src/__tests__/unit/agents.test.ts b/src/__tests__/unit/agents.test.ts index 77a0df9..0715ba5 100644 --- a/src/__tests__/unit/agents.test.ts +++ b/src/__tests__/unit/agents.test.ts @@ -451,7 +451,7 @@ describe('GuardrailAgent', () => { expect(result.outputInfo.input).toBe('Latest user message with additional context.'); }); - it('should handle guardrail execution errors based on raiseGuardrailErrors setting', async () => { + it('should handle guardrail execution errors based on raiseGuardrailErrors setting', async () => { process.env.OPENAI_API_KEY = 'test'; const config = { version: 1, @@ -547,4 +547,68 @@ describe('GuardrailAgent', () => { ); }); }); + + it('propagates guardrail metadata to outputInfo on success', async () => { + process.env.OPENAI_API_KEY = 'test'; + const config = { + version: 1, + input: { + version: 1, + guardrails: [{ name: 'Jailbreak', config: {} }], + }, + }; + + const { instantiateGuardrails } = await import('../../runtime'); + vi.mocked(instantiateGuardrails).mockImplementationOnce(() => + Promise.resolve([ + { + definition: { + name: 'Jailbreak', + description: 'Test guardrail', + mediaType: 'text/plain', + configSchema: z.object({}), + checkFn: vi.fn(), + metadata: {}, + ctxRequirements: z.object({}), + schema: () => ({}), + instantiate: vi.fn(), + }, + config: {}, + run: vi.fn().mockResolvedValue({ + tripwireTriggered: false, + info: { + guardrail_name: 'Jailbreak', + flagged: false, + token_usage: { + prompt_tokens: 42, + completion_tokens: 10, + total_tokens: 52, + }, + }, + }), + } as unknown as Parameters[0] extends Promise + ? T extends readonly (infer U)[] + ? 
U + : never + : never, + ]) + ); + + const agent = (await GuardrailAgent.create( + config, + 'Metadata Agent', + 'Test instructions' + )) as MockAgent; + + const guardrailFunction = agent.inputGuardrails[0]; + const result = await guardrailFunction.execute('payload'); + + expect(result.tripwireTriggered).toBe(false); + expect(result.outputInfo.guardrail_name).toBe('Jailbreak'); + expect(result.outputInfo.token_usage).toEqual({ + prompt_tokens: 42, + completion_tokens: 10, + total_tokens: 52, + }); + }); }); diff --git a/src/__tests__/unit/base-client.test.ts b/src/__tests__/unit/base-client.test.ts index 3b03ee6..4dc0a38 100644 --- a/src/__tests__/unit/base-client.test.ts +++ b/src/__tests__/unit/base-client.test.ts @@ -368,3 +368,64 @@ describe('GuardrailsBaseClient helpers', () => { }); }); }); + +describe('GuardrailResultsImpl token usage', () => { + it('aggregates usage across all stages', () => { + const results = new GuardrailResultsImpl( + [ + { + tripwireTriggered: false, + info: { + guardrail_name: 'Jailbreak', + token_usage: { + prompt_tokens: 100, + completion_tokens: 40, + total_tokens: 140, + }, + }, + }, + ], + [], + [ + { + tripwireTriggered: false, + info: { + guardrail_name: 'NSFW', + token_usage: { + prompt_tokens: 50, + completion_tokens: 30, + total_tokens: 80, + }, + }, + }, + ] + ); + + expect(results.totalTokenUsage).toEqual({ + prompt_tokens: 150, + completion_tokens: 70, + total_tokens: 220, + }); + }); + + it('returns null totals when no guardrail reports usage', () => { + const results = new GuardrailResultsImpl( + [ + { + tripwireTriggered: false, + info: { + guardrail_name: 'Moderation', + }, + }, + ], + [], + [] + ); + + expect(results.totalTokenUsage).toEqual({ + prompt_tokens: null, + completion_tokens: null, + total_tokens: null, + }); + }); +}); diff --git a/src/__tests__/unit/checks/jailbreak.test.ts b/src/__tests__/unit/checks/jailbreak.test.ts index 01b8d0d..ae2c177 100644 --- a/src/__tests__/unit/checks/jailbreak.test.ts +++ b/src/__tests__/unit/checks/jailbreak.test.ts @@ -39,11 +39,18 @@ describe('jailbreak guardrail', () => { it('passes trimmed latest input and recent history to runLLM', async () => { const { jailbreak, MAX_CONTEXT_TURNS } = await import('../../../checks/jailbreak'); - runLLMMock.mockResolvedValue({ - flagged: true, - confidence: 0.92, - reason: 'Detected escalation.', - }); + runLLMMock.mockResolvedValue([ + { + flagged: true, + confidence: 0.92, + reason: 'Detected escalation.', + }, + { + prompt_tokens: 120, + completion_tokens: 40, + total_tokens: 160, + }, + ]); const history = Array.from({ length: MAX_CONTEXT_TURNS + 2 }, (_, i) => ({ role: 'user', @@ -76,16 +83,28 @@ describe('jailbreak guardrail', () => { expect(result.tripwireTriggered).toBe(true); expect(result.info.used_conversation_history).toBe(true); expect(result.info.reason).toBe('Detected escalation.'); + expect(result.info.token_usage).toEqual({ + prompt_tokens: 120, + completion_tokens: 40, + total_tokens: 160, + }); }); it('falls back to latest input when no history is available', async () => { const { jailbreak } = await import('../../../checks/jailbreak'); - runLLMMock.mockResolvedValue({ - flagged: false, - confidence: 0.1, - reason: 'Benign request.', - }); + runLLMMock.mockResolvedValue([ + { + flagged: false, + confidence: 0.1, + reason: 'Benign request.', + }, + { + prompt_tokens: 60, + completion_tokens: 20, + total_tokens: 80, + }, + ]); const context = { guardrailLlm: {} as unknown, @@ -106,18 +125,31 @@ describe('jailbreak guardrail', () => { 
expect(result.tripwireTriggered).toBe(false); expect(result.info.used_conversation_history).toBe(false); expect(result.info.threshold).toBe(0.8); + expect(result.info.token_usage).toEqual({ + prompt_tokens: 60, + completion_tokens: 20, + total_tokens: 80, + }); }); it('uses createErrorResult when runLLM returns an error output', async () => { const { jailbreak } = await import('../../../checks/jailbreak'); - runLLMMock.mockResolvedValue({ - flagged: false, - confidence: 0, - info: { - error_message: 'timeout', + runLLMMock.mockResolvedValue([ + { + flagged: false, + confidence: 0, + info: { + error_message: 'timeout', + }, }, - }); + { + prompt_tokens: null, + completion_tokens: null, + total_tokens: null, + unavailable_reason: 'LLM call failed before usage could be recorded', + }, + ]); const context = { guardrailLlm: {} as unknown, @@ -134,5 +166,11 @@ describe('jailbreak guardrail', () => { expect(result.info.error_message).toBe('timeout'); expect(result.info.checked_text).toBeDefined(); expect(result.info.used_conversation_history).toBe(true); + expect(result.info.token_usage).toEqual({ + prompt_tokens: null, + completion_tokens: null, + total_tokens: null, + unavailable_reason: 'LLM call failed before usage could be recorded', + }); }); }); diff --git a/src/__tests__/unit/llm-base.test.ts b/src/__tests__/unit/llm-base.test.ts index c9873bb..07684fe 100644 --- a/src/__tests__/unit/llm-base.test.ts +++ b/src/__tests__/unit/llm-base.test.ts @@ -118,6 +118,11 @@ describe('LLM Base', () => { }, }, ], + usage: { + prompt_tokens: 20, + completion_tokens: 10, + total_tokens: 30, + }, }), }, }, @@ -133,6 +138,11 @@ describe('LLM Base', () => { expect(result.info.guardrail_name).toBe('Test Guardrail'); expect(result.info.flagged).toBe(true); expect(result.info.confidence).toBe(0.8); + expect(result.info.token_usage).toEqual({ + prompt_tokens: 20, + completion_tokens: 10, + total_tokens: 30, + }); }); it('should fail open on schema validation error and not trigger tripwire', async () => { @@ -155,6 +165,11 @@ describe('LLM Base', () => { }, }, ], + usage: { + prompt_tokens: 12, + completion_tokens: 4, + total_tokens: 16, + }, }), }, }, @@ -170,7 +185,13 @@ describe('LLM Base', () => { expect(result.executionFailed).toBe(true); expect(result.info.flagged).toBe(false); expect(result.info.confidence).toBe(0.0); - expect(result.info.info.error_message).toBe('LLM response validation failed.'); + expect(result.info.error_message).toBe('LLM response validation failed.'); + expect(result.info.token_usage).toEqual({ + prompt_tokens: null, + completion_tokens: null, + total_tokens: null, + unavailable_reason: 'LLM call failed before usage could be recorded', + }); }); it('should fail open on malformed JSON and not trigger tripwire', async () => { @@ -193,6 +214,11 @@ describe('LLM Base', () => { }, }, ], + usage: { + prompt_tokens: 8, + completion_tokens: 3, + total_tokens: 11, + }, }), }, }, @@ -208,7 +234,13 @@ describe('LLM Base', () => { expect(result.executionFailed).toBe(true); expect(result.info.flagged).toBe(false); expect(result.info.confidence).toBe(0.0); - expect(result.info.info.error_message).toBe('LLM returned non-JSON or malformed JSON.'); + expect(result.info.error_message).toBe('LLM returned non-JSON or malformed JSON.'); + expect(result.info.token_usage).toEqual({ + prompt_tokens: null, + completion_tokens: null, + total_tokens: null, + unavailable_reason: 'LLM call failed before usage could be recorded', + }); }); }); }); diff --git 
a/src/__tests__/unit/prompt_injection_detection.test.ts b/src/__tests__/unit/prompt_injection_detection.test.ts index 4ff5cfe..c352e16 100644 --- a/src/__tests__/unit/prompt_injection_detection.test.ts +++ b/src/__tests__/unit/prompt_injection_detection.test.ts @@ -132,6 +132,7 @@ describe('Prompt Injection Detection Check', () => { expect(result.info.confidence).toBeLessThan(config.confidence_threshold); expect(result.info.guardrail_name).toBe('Prompt Injection Detection'); expect(result.info.evidence).toBeNull(); + expect(result.info.token_usage).toBeDefined(); }); it('should handle context with previous messages', async () => { diff --git a/src/__tests__/unit/types.test.ts b/src/__tests__/unit/types.test.ts index 3554f2a..79cc7ae 100644 --- a/src/__tests__/unit/types.test.ts +++ b/src/__tests__/unit/types.test.ts @@ -9,7 +9,13 @@ */ import { describe, it, expect } from 'vitest'; -import { GuardrailResult, GuardrailLLMContext } from '../../types'; +import { + GuardrailResult, + GuardrailLLMContext, + aggregateTokenUsageFromInfos, + extractTokenUsage, + totalGuardrailTokenUsage, +} from '../../types'; import { OpenAI } from 'openai'; describe('Types Module', () => { @@ -145,3 +151,81 @@ describe('Types Module', () => { }); }); }); + +describe('token usage helpers', () => { + it('extractTokenUsage returns counts when usage present', () => { + const usage = extractTokenUsage({ + usage: { + prompt_tokens: 12, + completion_tokens: 4, + total_tokens: 16, + }, + }); + + expect(usage).toEqual({ + prompt_tokens: 12, + completion_tokens: 4, + total_tokens: 16, + }); + }); + + it('aggregateTokenUsageFromInfos sums values', () => { + const summary = aggregateTokenUsageFromInfos([ + { + token_usage: { + prompt_tokens: 50, + completion_tokens: 10, + total_tokens: 60, + }, + }, + { + token_usage: { + prompt_tokens: 25, + completion_tokens: 5, + total_tokens: 30, + }, + }, + ]); + + expect(summary).toEqual({ + prompt_tokens: 75, + completion_tokens: 15, + total_tokens: 90, + }); + }); + + it('totalGuardrailTokenUsage handles GuardrailResults objects', () => { + const guardrailResults = { + totalTokenUsage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + }; + + const totals = totalGuardrailTokenUsage({ guardrail_results: guardrailResults }); + expect(totals).toEqual({ + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }); + }); + + it('totalGuardrailTokenUsage handles camelCase guardrailResults property', () => { + const result = totalGuardrailTokenUsage({ + guardrailResults: { + totalTokenUsage: { + prompt_tokens: 25, + completion_tokens: 10, + total_tokens: 35, + }, + }, + }); + + expect(result).toEqual({ + prompt_tokens: 25, + completion_tokens: 10, + total_tokens: 35, + }); + }); +}); diff --git a/src/base-client.ts b/src/base-client.ts index a7b4176..86c5c66 100644 --- a/src/base-client.ts +++ b/src/base-client.ts @@ -13,6 +13,8 @@ import { Message, ContentPart, TextContentPart, + TokenUsageSummary, + aggregateTokenUsageFromInfos, } from './types'; import { ContentUtils } from './utils/content'; import { @@ -81,6 +83,7 @@ export interface GuardrailResults { readonly allResults: GuardrailResult[]; readonly tripwiresTriggered: boolean; readonly triggeredResults: GuardrailResult[]; + readonly totalTokenUsage: TokenUsageSummary; } /** @@ -104,6 +107,10 @@ export class GuardrailResultsImpl implements GuardrailResults { get triggeredResults(): GuardrailResult[] { return this.allResults.filter((r) => r.tripwireTriggered); } + + get totalTokenUsage(): 
TokenUsageSummary { + return aggregateTokenUsageFromInfos(this.allResults.map((result) => result.info)); + } } /** diff --git a/src/checks/hallucination-detection.ts b/src/checks/hallucination-detection.ts index f4d995d..3cd02f6 100644 --- a/src/checks/hallucination-detection.ts +++ b/src/checks/hallucination-detection.ts @@ -18,7 +18,14 @@ */ import { z } from 'zod'; -import { CheckFn, GuardrailResult, GuardrailLLMContext } from '../types'; +import { + CheckFn, + GuardrailResult, + GuardrailLLMContext, + TokenUsage, + extractTokenUsage, + tokenUsageToDict, +} from '../types'; import { defaultSpecRegistry } from '../registry'; import { createErrorResult, LLMErrorOutput } from './llm-base'; @@ -159,6 +166,13 @@ export const hallucination_detection: CheckFn< throw new Error("knowledge_source must be a valid vector store ID starting with 'vs_'"); } + let tokenUsage: TokenUsage = Object.freeze({ + prompt_tokens: null, + completion_tokens: null, + total_tokens: null, + unavailable_reason: 'LLM call failed before usage could be recorded', + }); + try { // Create the validation query const validationQuery = `${VALIDATION_PROMPT}\n\nText to validate:\n${candidate}`; @@ -175,6 +189,8 @@ export const hallucination_detection: CheckFn< ], }); + tokenUsage = extractTokenUsage(response); + // Extract the analysis from the response // The response will contain the LLM's analysis in output_text const outputText = response.output_text; @@ -203,13 +219,18 @@ export const hallucination_detection: CheckFn< confidence: 0.0, info: { error_message: `JSON parsing failed: ${error instanceof Error ? error.message : String(error)}` }, }; - return createErrorResult('Hallucination Detection', errorOutput, { - threshold: config.confidence_threshold, - reasoning: 'LLM response could not be parsed as JSON', - hallucination_type: null, - hallucinated_statements: null, - verified_statements: null, - }); + return createErrorResult( + 'Hallucination Detection', + errorOutput, + { + threshold: config.confidence_threshold, + reasoning: 'LLM response could not be parsed as JSON', + hallucination_type: null, + hallucinated_statements: null, + verified_statements: null, + }, + tokenUsage + ); } const analysis = HallucinationDetectionOutput.parse(parsedJson); @@ -228,6 +249,7 @@ export const hallucination_detection: CheckFn< hallucinated_statements: analysis.hallucinated_statements, verified_statements: analysis.verified_statements, threshold: config.confidence_threshold, + token_usage: tokenUsageToDict(tokenUsage), }, }; } catch (error) { @@ -238,13 +260,18 @@ export const hallucination_detection: CheckFn< confidence: 0.0, info: { error_message: error instanceof Error ? error.message : String(error) }, }; - return createErrorResult('Hallucination Detection', errorOutput, { - threshold: config.confidence_threshold, - reasoning: `Analysis failed: ${error instanceof Error ? error.message : String(error)}`, - hallucination_type: null, - hallucinated_statements: null, - verified_statements: null, - }); + return createErrorResult( + 'Hallucination Detection', + errorOutput, + { + threshold: config.confidence_threshold, + reasoning: `Analysis failed: ${error instanceof Error ? 
error.message : String(error)}`, + hallucination_type: null, + hallucinated_statements: null, + verified_statements: null, + }, + tokenUsage + ); } }; diff --git a/src/checks/jailbreak.ts b/src/checks/jailbreak.ts index 49215bd..6da0e6b 100644 --- a/src/checks/jailbreak.ts +++ b/src/checks/jailbreak.ts @@ -8,7 +8,7 @@ */ import { z } from 'zod'; -import { CheckFn, GuardrailResult, GuardrailLLMContext } from '../types'; +import { CheckFn, GuardrailResult, GuardrailLLMContext, tokenUsageToDict } from '../types'; import { LLMConfig, LLMOutput, LLMErrorOutput, createErrorResult, runLLM } from './llm-base'; import { defaultSpecRegistry } from '../registry'; @@ -224,7 +224,7 @@ export const jailbreak: CheckFn = asy const conversationHistory = extractConversationHistory(ctx); const analysisPayload = buildAnalysisPayload(conversationHistory, data); - const analysis = await runLLM( + const [analysis, tokenUsage] = await runLLM( analysisPayload, SYSTEM_PROMPT, ctx.guardrailLlm, @@ -235,10 +235,15 @@ export const jailbreak: CheckFn = asy const usedConversationHistory = conversationHistory.length > 0; if (isLLMErrorOutput(analysis)) { - return createErrorResult('Jailbreak', analysis, { - checked_text: analysisPayload, - used_conversation_history: usedConversationHistory, - }); + return createErrorResult( + 'Jailbreak', + analysis, + { + checked_text: analysisPayload, + used_conversation_history: usedConversationHistory, + }, + tokenUsage + ); } const isTriggered = analysis.flagged && analysis.confidence >= config.confidence_threshold; @@ -251,6 +256,7 @@ export const jailbreak: CheckFn = asy threshold: config.confidence_threshold, checked_text: analysisPayload, used_conversation_history: usedConversationHistory, + token_usage: tokenUsageToDict(tokenUsage), }, }; }; diff --git a/src/checks/llm-base.ts b/src/checks/llm-base.ts index 02b245f..a528130 100644 --- a/src/checks/llm-base.ts +++ b/src/checks/llm-base.ts @@ -9,7 +9,14 @@ import { z, ZodTypeAny } from 'zod'; import { OpenAI } from 'openai'; -import { CheckFn, GuardrailResult, GuardrailLLMContext } from '../types'; +import { + CheckFn, + GuardrailResult, + GuardrailLLMContext, + TokenUsage, + extractTokenUsage, + tokenUsageToDict, +} from '../types'; import { defaultSpecRegistry } from '../registry'; import { SAFETY_IDENTIFIER, supportsSafetyIdentifier } from '../utils/safety-identifier'; @@ -78,7 +85,8 @@ export type LLMErrorOutput = z.infer; export function createErrorResult( guardrailName: string, analysis: LLMErrorOutput, - additionalInfo: Record = {} + additionalInfo: Record = {}, + tokenUsage?: TokenUsage ): GuardrailResult { return { tripwireTriggered: false, @@ -90,6 +98,7 @@ export function createErrorResult( confidence: analysis.confidence, ...analysis.info, ...additionalInfo, + ...(tokenUsage ? 
{ token_usage: tokenUsageToDict(tokenUsage) } : {}), }, }; } @@ -273,8 +282,14 @@ export async function runLLM( client: OpenAI, model: string, outputModel: TOutput -): Promise | LLMErrorOutput> { +): Promise<[z.infer | LLMErrorOutput, TokenUsage]> { const fullPrompt = buildFullPrompt(systemPrompt, outputModel); + const noUsage: TokenUsage = Object.freeze({ + prompt_tokens: null, + completion_tokens: null, + total_tokens: null, + unavailable_reason: 'LLM call failed before usage could be recorded', + }); try { // Handle temperature based on model capabilities @@ -304,68 +319,84 @@ export async function runLLM( // @ts-ignore - safety_identifier is not in the OpenAI types yet const response = await client.chat.completions.create(params); + const tokenUsage = extractTokenUsage(response); const result = response.choices[0]?.message?.content; if (!result) { - return LLMErrorOutput.parse({ - flagged: false, - confidence: 0.0, - info: { - error_message: 'LLM returned no content', - }, - }); + return [ + LLMErrorOutput.parse({ + flagged: false, + confidence: 0.0, + info: { + error_message: 'LLM returned no content', + }, + }), + tokenUsage, + ]; } const cleanedResult = stripJsonCodeFence(result); - return outputModel.parse(JSON.parse(cleanedResult)); + return [outputModel.parse(JSON.parse(cleanedResult)), tokenUsage]; } catch (error) { console.error('LLM guardrail failed for prompt:', systemPrompt, error); // Check if this is a content filter error - Azure OpenAI if (error && typeof error === 'string' && error.includes('content_filter')) { console.warn('Content filter triggered by provider:', error); - return LLMErrorOutput.parse({ - flagged: true, - confidence: 1.0, - info: { - third_party_filter: true, - error_message: String(error), - }, - }); + return [ + LLMErrorOutput.parse({ + flagged: true, + confidence: 1.0, + info: { + third_party_filter: true, + error_message: String(error), + }, + }), + noUsage, + ]; } // Fail-open on JSON parsing errors (malformed or non-JSON responses) if (error instanceof SyntaxError || (error as Error)?.constructor?.name === 'SyntaxError') { console.warn('LLM returned non-JSON or malformed JSON.', error); - return LLMErrorOutput.parse({ - flagged: false, - confidence: 0.0, - info: { - error_message: 'LLM returned non-JSON or malformed JSON.', - }, - }); + return [ + LLMErrorOutput.parse({ + flagged: false, + confidence: 0.0, + info: { + error_message: 'LLM returned non-JSON or malformed JSON.', + }, + }), + noUsage, + ]; } // Fail-open on schema validation errors (e.g., wrong types like confidence as string) if (error instanceof z.ZodError) { console.warn('LLM response validation failed.', error); - return LLMErrorOutput.parse({ + return [ + LLMErrorOutput.parse({ + flagged: false, + confidence: 0.0, + info: { + error_message: 'LLM response validation failed.', + zod_issues: error.issues ?? [], + }, + }), + noUsage, + ]; + } + + // Always return error information for other LLM failures + return [ + LLMErrorOutput.parse({ flagged: false, confidence: 0.0, info: { - error_message: 'LLM response validation failed.', - zod_issues: error.issues ?? 
[], + error_message: String(error), }, - }); - } - - // Always return error information for other LLM failures - return LLMErrorOutput.parse({ - flagged: false, - confidence: 0.0, - info: { - error_message: String(error), - }, - }); + }), + noUsage, + ]; } } @@ -408,7 +439,7 @@ export function createLLMCheckFn( ); } - const analysis = await runLLM( + const [analysis, tokenUsage] = await runLLM( data, renderedSystemPrompt, ctx.guardrailLlm as OpenAI, // Type assertion to handle OpenAI client compatibility @@ -417,15 +448,7 @@ export function createLLMCheckFn( ); if (isLLMErrorOutput(analysis)) { - return { - tripwireTriggered: false, - executionFailed: true, - originalException: new Error(String(analysis.info?.error_message || 'LLM execution failed')), - info: { - guardrail_name: name, - ...analysis, - }, - }; + return createErrorResult(name, analysis, undefined, tokenUsage); } const isTrigger = analysis.flagged && analysis.confidence >= config.confidence_threshold; @@ -435,6 +458,7 @@ export function createLLMCheckFn( guardrail_name: name, ...analysis, threshold: config.confidence_threshold, + token_usage: tokenUsageToDict(tokenUsage), }, }; } diff --git a/src/checks/prompt_injection_detection.ts b/src/checks/prompt_injection_detection.ts index dbf5842..3891a30 100644 --- a/src/checks/prompt_injection_detection.ts +++ b/src/checks/prompt_injection_detection.ts @@ -17,6 +17,8 @@ import { GuardrailLLMContext, GuardrailLLMContextWithHistory, ConversationMessage, + TokenUsage, + tokenUsageToDict, } from '../types'; import { defaultSpecRegistry } from '../registry'; import { LLMOutput, runLLM } from './llm-base'; @@ -220,7 +222,11 @@ export const promptInjectionDetectionCheck: CheckFn< } const analysisPrompt = buildAnalysisPrompt(userGoalText, recentMessages, actionableMessages); - const analysis = await callPromptInjectionDetectionLLM(ctx, analysisPrompt, config); + const { analysis, tokenUsage } = await callPromptInjectionDetectionLLM( + ctx, + analysisPrompt, + config + ); const isMisaligned = analysis.flagged && analysis.confidence >= config.confidence_threshold; @@ -237,6 +243,7 @@ export const promptInjectionDetectionCheck: CheckFn< action: actionableMessages, recent_messages: recentMessages, recent_messages_json: checkedText, + token_usage: tokenUsageToDict(tokenUsage), }, }; } catch (error) { @@ -360,13 +367,21 @@ function isActionableMessage(message: NormalizedConversationEntry): boolean { return false; } +const SKIPPED_USAGE: TokenUsage = Object.freeze({ + prompt_tokens: null, + completion_tokens: null, + total_tokens: null, + unavailable_reason: 'No LLM call made (check was skipped)', +}); + function createSkipResult( observation: string, threshold: number, recentMessagesJson: string, userGoal: string = 'N/A', action: ConversationMessage[] = [], - recentMessages: ConversationMessage[] = [] + recentMessages: ConversationMessage[] = [], + tokenUsage: TokenUsage = SKIPPED_USAGE ): GuardrailResult { return { tripwireTriggered: false, @@ -381,6 +396,7 @@ function createSkipResult( action: action ?? 
[], recent_messages: recentMessages, recent_messages_json: recentMessagesJson, + token_usage: tokenUsageToDict(tokenUsage), }, }; } @@ -429,9 +445,23 @@ async function callPromptInjectionDetectionLLM( ctx: GuardrailLLMContext, prompt: string, config: PromptInjectionDetectionConfig -): Promise { +): Promise<{ analysis: PromptInjectionDetectionOutput; tokenUsage: TokenUsage }> { + const fallbackOutput: PromptInjectionDetectionOutput = { + flagged: false, + confidence: 0.0, + observation: 'LLM analysis failed - using fallback values', + evidence: null, + }; + + const fallbackUsage: TokenUsage = Object.freeze({ + prompt_tokens: null, + completion_tokens: null, + total_tokens: null, + unavailable_reason: 'Prompt injection detection LLM call failed', + }); + try { - const result = await runLLM( + const [result, tokenUsage] = await runLLM( prompt, '', ctx.guardrailLlm, @@ -439,14 +469,23 @@ async function callPromptInjectionDetectionLLM( PromptInjectionDetectionOutput ); - return PromptInjectionDetectionOutput.parse(result); - } catch { - console.warn('Prompt injection detection LLM call failed, using fallback'); + try { + return { + analysis: PromptInjectionDetectionOutput.parse(result), + tokenUsage, + }; + } catch (parseError) { + console.warn('Prompt injection detection LLM parsing failed, using fallback', parseError); + return { + analysis: fallbackOutput, + tokenUsage, + }; + } + } catch (error) { + console.warn('Prompt injection detection LLM call failed, using fallback', error); return { - flagged: false, - confidence: 0.0, - observation: 'LLM analysis failed - using fallback values', - evidence: null, + analysis: fallbackOutput, + tokenUsage: fallbackUsage, }; } } diff --git a/src/index.ts b/src/index.ts index 35fcc8f..29fb296 100644 --- a/src/index.ts +++ b/src/index.ts @@ -7,7 +7,7 @@ */ // Core types and interfaces -export { GuardrailResult, GuardrailLLMContext, CheckFn } from './types'; +export { GuardrailResult, GuardrailLLMContext, CheckFn, totalGuardrailTokenUsage } from './types'; // Exception types export { @@ -65,4 +65,4 @@ export { GuardrailAgent } from './agents'; export { main as cli } from './cli'; // Re-export commonly used types -export type { MaybeAwaitableResult } from './types'; +export type { MaybeAwaitableResult, TokenUsage, TokenUsageSummary } from './types'; diff --git a/src/types.ts b/src/types.ts index 5dfea81..6275c70 100644 --- a/src/types.ts +++ b/src/types.ts @@ -155,3 +155,249 @@ export type TextOnlyMessage = { /** Array of text-only messages */ export type TextOnlyMessageArray = TextOnlyMessage[]; + +/** + * Token usage statistics emitted by LLM-based guardrails. + */ +export type TokenUsage = Readonly<{ + prompt_tokens: number | null; + completion_tokens: number | null; + total_tokens: number | null; + unavailable_reason?: string | null; +}>; + +/** + * Aggregated token usage summary across multiple guardrails. 
+ */
+export type TokenUsageSummary = Readonly<{
+  prompt_tokens: number | null;
+  completion_tokens: number | null;
+  total_tokens: number | null;
+}>;
+
+type UsageRecord = {
+  prompt_tokens?: unknown;
+  completion_tokens?: unknown;
+  total_tokens?: unknown;
+  input_tokens?: unknown;
+  output_tokens?: unknown;
+};
+
+const EMPTY_TOKEN_USAGE_SUMMARY: TokenUsageSummary = Object.freeze({
+  prompt_tokens: null,
+  completion_tokens: null,
+  total_tokens: null,
+});
+
+function isRecord(value: unknown): value is Record<string, unknown> {
+  return typeof value === 'object' && value !== null;
+}
+
+function isIterable(value: unknown): value is Iterable<unknown> {
+  return (
+    typeof value === 'object' &&
+    value !== null &&
+    typeof (value as Iterable<unknown>)[Symbol.iterator] === 'function'
+  );
+}
+
+function readNumber(value: unknown): number | null {
+  return typeof value === 'number' && Number.isFinite(value) ? value : null;
+}
+
+function pickNumber(record: UsageRecord | null | undefined, keys: (keyof UsageRecord)[]): number | null {
+  if (!record) {
+    return null;
+  }
+
+  for (const key of keys) {
+    const candidate = record[key];
+    const numeric = readNumber(candidate);
+    if (numeric !== null) {
+      return numeric;
+    }
+  }
+
+  return null;
+}
+
+/**
+ * Extract token usage data from an OpenAI API response object.
+ */
+export function extractTokenUsage(response: unknown): TokenUsage {
+  const usage = (response as { usage?: UsageRecord | null })?.usage;
+  if (!usage) {
+    return Object.freeze({
+      prompt_tokens: null,
+      completion_tokens: null,
+      total_tokens: null,
+      unavailable_reason: 'Token usage not available for this model provider',
+    }) as TokenUsage;
+  }
+
+  const promptTokens = pickNumber(usage, ['prompt_tokens', 'input_tokens']);
+  const completionTokens = pickNumber(usage, ['completion_tokens', 'output_tokens']);
+  const totalTokens = pickNumber(usage, ['total_tokens']);
+
+  if (promptTokens === null && completionTokens === null && totalTokens === null) {
+    return Object.freeze({
+      prompt_tokens: null,
+      completion_tokens: null,
+      total_tokens: null,
+      unavailable_reason: 'Token usage data not populated in response',
+    }) as TokenUsage;
+  }
+
+  return Object.freeze({
+    prompt_tokens: promptTokens,
+    completion_tokens: completionTokens,
+    total_tokens: totalTokens,
+  }) as TokenUsage;
+}
+
+/**
+ * Convert a TokenUsage object into a plain dictionary suitable for serialization.
+ */
+export function tokenUsageToDict(tokenUsage: TokenUsage): TokenUsage {
+  const result: Record<string, number | null> & { unavailable_reason?: string | null } = {
+    prompt_tokens: tokenUsage.prompt_tokens,
+    completion_tokens: tokenUsage.completion_tokens,
+    total_tokens: tokenUsage.total_tokens,
+  };
+
+  if (tokenUsage.unavailable_reason !== undefined) {
+    result.unavailable_reason = tokenUsage.unavailable_reason;
+  }
+
+  return Object.freeze(result) as TokenUsage;
+}
+
+/**
+ * Aggregate token usage values from a collection of guardrail info dictionaries.
+ */ +export function aggregateTokenUsageFromInfos( + infoDicts: Iterable | null | undefined> +): TokenUsageSummary { + let totalPrompt = 0; + let totalCompletion = 0; + let totalTokens = 0; + let hasData = false; + + for (const info of infoDicts) { + if (!info) { + continue; + } + + const usage = info.token_usage; + if (!isRecord(usage)) { + continue; + } + + const prompt = readNumber(usage.prompt_tokens); + const completion = readNumber(usage.completion_tokens); + const total = readNumber(usage.total_tokens); + + if (prompt === null && completion === null && total === null) { + continue; + } + + hasData = true; + if (prompt !== null) { + totalPrompt += prompt; + } + if (completion !== null) { + totalCompletion += completion; + } + if (total !== null) { + totalTokens += total; + } + } + + if (!hasData) { + return EMPTY_TOKEN_USAGE_SUMMARY; + } + + return Object.freeze({ + prompt_tokens: totalPrompt, + completion_tokens: totalCompletion, + total_tokens: totalTokens, + }) as TokenUsageSummary; +} + +const AGENT_RESULT_ATTRS = [ + 'input_guardrail_results', + 'output_guardrail_results', + 'tool_input_guardrail_results', + 'tool_output_guardrail_results', + 'inputGuardrailResults', + 'outputGuardrailResults', + 'toolInputGuardrailResults', + 'toolOutputGuardrailResults', +] as const; + +function extractAgentsSdkInfos(stageResults: unknown): Record[] { + if (!stageResults) { + return []; + } + + const entries: unknown[] = Array.isArray(stageResults) + ? stageResults + : isIterable(stageResults) + ? Array.from(stageResults as Iterable) + : []; + + const infos: Record[] = []; + for (const entry of entries) { + if (!isRecord(entry)) { + continue; + } + + const direct = entry.output_info ?? entry.outputInfo; + if (isRecord(direct)) { + infos.push(direct); + continue; + } + + const output = entry.output; + if (isRecord(output)) { + const nested = output.output_info ?? output.outputInfo; + if (isRecord(nested)) { + infos.push(nested); + } + } + } + + return infos; +} + +/** + * Unified helper to compute total guardrail token usage from any result shape. + */ +export function totalGuardrailTokenUsage(result: unknown): TokenUsageSummary { + if (!isRecord(result)) { + return EMPTY_TOKEN_USAGE_SUMMARY; + } + + const guardrailResults = result.guardrail_results ?? 
result.guardrailResults; + if (isRecord(guardrailResults)) { + const totals = (guardrailResults as { totalTokenUsage?: TokenUsageSummary }).totalTokenUsage; + if (totals) { + return totals; + } + } + + const directTotals = (result as { totalTokenUsage?: TokenUsageSummary }).totalTokenUsage; + if (directTotals) { + return directTotals; + } + + const infos: Record[] = []; + for (const attr of AGENT_RESULT_ATTRS) { + const stageResults = result[attr]; + if (stageResults) { + infos.push(...extractAgentsSdkInfos(stageResults)); + } + } + + if (infos.length === 0) { + return EMPTY_TOKEN_USAGE_SUMMARY; + } + + return aggregateTokenUsageFromInfos(infos); +} From 30343deae7e194e8e4846bea1e8fdb906614d263 Mon Sep 17 00:00:00 2001 From: Steven C Date: Tue, 2 Dec 2025 12:31:07 -0500 Subject: [PATCH 2/3] Return token count even during execution error --- src/__tests__/unit/agents.test.ts | 2 +- src/__tests__/unit/llm-base.test.ts | 16 ++++++++-------- src/checks/llm-base.ts | 13 ++++++++++--- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/__tests__/unit/agents.test.ts b/src/__tests__/unit/agents.test.ts index 0715ba5..397d4c3 100644 --- a/src/__tests__/unit/agents.test.ts +++ b/src/__tests__/unit/agents.test.ts @@ -451,7 +451,7 @@ describe('GuardrailAgent', () => { expect(result.outputInfo.input).toBe('Latest user message with additional context.'); }); - it('should handle guardrail execution errors based on raiseGuardrailErrors setting', async () => { + it('should handle guardrail execution errors based on raiseGuardrailErrors setting', async () => { process.env.OPENAI_API_KEY = 'test'; const config = { version: 1, diff --git a/src/__tests__/unit/llm-base.test.ts b/src/__tests__/unit/llm-base.test.ts index 07684fe..522f63f 100644 --- a/src/__tests__/unit/llm-base.test.ts +++ b/src/__tests__/unit/llm-base.test.ts @@ -186,11 +186,11 @@ describe('LLM Base', () => { expect(result.info.flagged).toBe(false); expect(result.info.confidence).toBe(0.0); expect(result.info.error_message).toBe('LLM response validation failed.'); + // Token usage is now preserved even when schema validation fails expect(result.info.token_usage).toEqual({ - prompt_tokens: null, - completion_tokens: null, - total_tokens: null, - unavailable_reason: 'LLM call failed before usage could be recorded', + prompt_tokens: 12, + completion_tokens: 4, + total_tokens: 16, }); }); @@ -235,11 +235,11 @@ describe('LLM Base', () => { expect(result.info.flagged).toBe(false); expect(result.info.confidence).toBe(0.0); expect(result.info.error_message).toBe('LLM returned non-JSON or malformed JSON.'); + // Token usage is now preserved even when JSON parsing fails expect(result.info.token_usage).toEqual({ - prompt_tokens: null, - completion_tokens: null, - total_tokens: null, - unavailable_reason: 'LLM call failed before usage could be recorded', + prompt_tokens: 8, + completion_tokens: 3, + total_tokens: 11, }); }); }); diff --git a/src/checks/llm-base.ts b/src/checks/llm-base.ts index a528130..b778e36 100644 --- a/src/checks/llm-base.ts +++ b/src/checks/llm-base.ts @@ -291,6 +291,10 @@ export async function runLLM( unavailable_reason: 'LLM call failed before usage could be recorded', }); + // Declare tokenUsage outside try block so it's accessible in catch + // when JSON parsing or schema validation fails after a successful API call + let tokenUsage: TokenUsage = noUsage; + try { // Handle temperature based on model capabilities let temperature = 0.0; @@ -319,7 +323,8 @@ export async function runLLM( // @ts-ignore - 
safety_identifier is not in the OpenAI types yet
     const response = await client.chat.completions.create(params);
-    const tokenUsage = extractTokenUsage(response);
+    // Extract token usage immediately after API call so it's available even if parsing fails
+    tokenUsage = extractTokenUsage(response);
     const result = response.choices[0]?.message?.content;
     if (!result) {
       return [
@@ -356,6 +361,7 @@ export async function runLLM(
     }
 
     // Fail-open on JSON parsing errors (malformed or non-JSON responses)
+    // Use tokenUsage here since API call succeeded but response parsing failed
     if (error instanceof SyntaxError || (error as Error)?.constructor?.name === 'SyntaxError') {
       console.warn('LLM returned non-JSON or malformed JSON.', error);
       return [
@@ -366,11 +372,12 @@ export async function runLLM(
             error_message: 'LLM returned non-JSON or malformed JSON.',
           },
         }),
-        noUsage,
+        tokenUsage,
       ];
     }
 
     // Fail-open on schema validation errors (e.g., wrong types like confidence as string)
+    // Use tokenUsage here since API call succeeded but schema validation failed
     if (error instanceof z.ZodError) {
       console.warn('LLM response validation failed.', error);
       return [
@@ -382,7 +389,7 @@ export async function runLLM(
             zod_issues: error.issues ?? [],
           },
         }),
-        noUsage,
+        tokenUsage,
       ];
     }

From c42690497263c08c02525d554c6be12b0bf92365 Mon Sep 17 00:00:00 2001
From: Steven C
Date: Tue, 2 Dec 2025 12:40:31 -0500
Subject: [PATCH 3/3] Changed undefined typing

---
 src/checks/llm-base.ts | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/checks/llm-base.ts b/src/checks/llm-base.ts
index b778e36..d6ac356 100644
--- a/src/checks/llm-base.ts
+++ b/src/checks/llm-base.ts
@@ -455,7 +455,7 @@ export function createLLMCheckFn(
     );
 
     if (isLLMErrorOutput(analysis)) {
-      return createErrorResult(name, analysis, undefined, tokenUsage);
+      return createErrorResult(name, analysis, {}, tokenUsage);
     }
 
     const isTrigger = analysis.flagged && analysis.confidence >= config.confidence_threshold;