diff --git a/.changeset/ai-tokens-source.md b/.changeset/ai-tokens-source.md new file mode 100644 index 0000000000..e3c9f9bb2d --- /dev/null +++ b/.changeset/ai-tokens-source.md @@ -0,0 +1,5 @@ +--- +'@posthog/ai': patch +--- + +Add `$ai_tokens_source` property ("sdk" or "passthrough") to all `$ai_generation` events to detect when token values are externally overridden via `posthogProperties` diff --git a/packages/ai/src/utils.ts b/packages/ai/src/utils.ts index 35390c01ff..a5a13e94ea 100644 --- a/packages/ai/src/utils.ts +++ b/packages/ai/src/utils.ts @@ -18,6 +18,22 @@ type EmbeddingCreateParams = OpenAIOrignal.EmbeddingCreateParams type TranscriptionCreateParams = OpenAIOrignal.Audio.Transcriptions.TranscriptionCreateParams type AnthropicTool = AnthropicOriginal.Tool +const TOKEN_PROPERTY_KEYS = new Set([ + '$ai_input_tokens', + '$ai_output_tokens', + '$ai_cache_read_input_tokens', + '$ai_cache_creation_input_tokens', + '$ai_total_tokens', + '$ai_reasoning_tokens', +]) + +export function getTokensSource(posthogProperties?: Record): string { + if (posthogProperties && Object.keys(posthogProperties).some((key) => TOKEN_PROPERTY_KEYS.has(key))) { + return 'passthrough' + } + return 'sdk' +} + // limit large outputs by truncating to 200kb (approx 200k bytes) export const MAX_OUTPUT_SIZE = 200000 const STRING_FORMAT = 'utf8' @@ -727,6 +743,7 @@ export const sendEventToPosthog = async ({ $ai_trace_id: traceId, $ai_base_url: baseURL, ...params.posthogProperties, + $ai_tokens_source: getTokensSource(params.posthogProperties), ...(distinctId ? {} : { $process_person_profile: false }), ...(tools ? { $ai_tools: tools } : {}), ...errorData, diff --git a/packages/ai/tests/anthropic.test.ts b/packages/ai/tests/anthropic.test.ts index 06f3caf6dd..4bfcc5f499 100644 --- a/packages/ai/tests/anthropic.test.ts +++ b/packages/ai/tests/anthropic.test.ts @@ -389,6 +389,25 @@ describe('PostHogAnthropic', () => { const [captureArgs] = captureMock.mock.calls const { properties } = captureArgs[0] expect(properties['$ai_usage']).toBeDefined() + expect(properties['$ai_tokens_source']).toBe('sdk') + }) + + conditionalTest('should set tokens_source to passthrough when token properties are overridden', async () => { + const response = await client.messages.create({ + model: 'claude-3-opus-20240229', + messages: [{ role: 'user', content: 'Hello Claude' }], + max_tokens: 100, + posthogDistinctId: 'test-user-123', + posthogProperties: { $ai_input_tokens: 99999 }, + }) + + expect(response).toEqual(mockResponse) + + const captureMock = mockPostHogClient.capture as jest.Mock + const [captureArgs] = captureMock.mock.calls + const { properties } = captureArgs[0] + expect(properties['$ai_tokens_source']).toBe('passthrough') + expect(properties['$ai_input_tokens']).toBe(99999) }) conditionalTest('should handle system prompts correctly', async () => { diff --git a/packages/ai/tests/utils.test.ts b/packages/ai/tests/utils.test.ts index 71a8c17d47..8c99372a0a 100644 --- a/packages/ai/tests/utils.test.ts +++ b/packages/ai/tests/utils.test.ts @@ -1,4 +1,21 @@ -import { toContentString } from '../src/utils' +import { toContentString, getTokensSource } from '../src/utils' + +describe('getTokensSource', () => { + it.each([ + ['undefined properties', undefined, 'sdk'], + ['empty properties', {}, 'sdk'], + ['unrelated properties', { foo: 'bar' }, 'sdk'], + ['$ai_input_tokens override', { $ai_input_tokens: 999 }, 'passthrough'], + ['$ai_output_tokens override', { $ai_output_tokens: 999 }, 'passthrough'], + ['$ai_total_tokens override', { $ai_total_tokens: 999 }, 'passthrough'], + ['$ai_cache_read_input_tokens override', { $ai_cache_read_input_tokens: 500 }, 'passthrough'], + ['$ai_cache_creation_input_tokens override', { $ai_cache_creation_input_tokens: 200 }, 'passthrough'], + ['$ai_reasoning_tokens override', { $ai_reasoning_tokens: 300 }, 'passthrough'], + ['mixed override and custom', { $ai_input_tokens: 999, custom_key: 'value' }, 'passthrough'], + ])('%s → %s', (_name, props, expected) => { + expect(getTokensSource(props)).toBe(expected) + }) +}) describe('toContentString', () => { describe('string inputs', () => {