diff --git a/packages/sdk/server-ai/__tests__/LDAIClientImpl.test.ts b/packages/sdk/server-ai/__tests__/LDAIClientImpl.test.ts index 41d035564a..e2a18b2c28 100644 --- a/packages/sdk/server-ai/__tests__/LDAIClientImpl.test.ts +++ b/packages/sdk/server-ai/__tests__/LDAIClientImpl.test.ts @@ -57,6 +57,7 @@ it('returns config with interpolated messages', async () => { ], tracker: expect.any(Object), enabled: true, + toVercelAISDK: expect.any(Function), }); }); @@ -102,6 +103,7 @@ it('handles missing metadata in variation', async () => { messages: [{ role: 'system', content: 'Hello' }], tracker: expect.any(Object), enabled: false, + toVercelAISDK: expect.any(Function), }); }); @@ -125,6 +127,7 @@ it('passes the default value to the underlying client', async () => { provider: defaultValue.provider, tracker: expect.any(Object), enabled: false, + toVercelAISDK: expect.any(Function), }); expect(mockLdClient.variation).toHaveBeenCalledWith(key, testContext, defaultValue); diff --git a/packages/sdk/server-ai/__tests__/LDAIConfigMapper.test.ts b/packages/sdk/server-ai/__tests__/LDAIConfigMapper.test.ts new file mode 100644 index 0000000000..ddee1b26f2 --- /dev/null +++ b/packages/sdk/server-ai/__tests__/LDAIConfigMapper.test.ts @@ -0,0 +1,159 @@ +import { LDMessage, VercelAISDKMapOptions } from '../src/api/config'; +import { LDAIConfigMapper } from '../src/LDAIConfigMapper'; + +describe('_findParameter', () => { + it('handles undefined model and messages', () => { + const mapper = new LDAIConfigMapper(); + // eslint-disable-next-line @typescript-eslint/dot-notation + expect(mapper['_findParameter']('test-param')).toBeUndefined(); + }); + + it('handles parameter not found', () => { + const mapper = new LDAIConfigMapper({ + name: 'test-ai-model', + parameters: { + 'test-param': 123, + }, + custom: { + 'test-param': 456, + }, + }); + // eslint-disable-next-line @typescript-eslint/dot-notation + expect(mapper['_findParameter']('other-param')).toBeUndefined(); + }); + + it('finds parameter from single model parameter', () => { + const mapper = new LDAIConfigMapper({ + name: 'test-ai-model', + parameters: { + 'test-param': 123, + }, + }); + // eslint-disable-next-line @typescript-eslint/dot-notation + expect(mapper['_findParameter']('test-param')).toEqual(123); + }); + + it('finds parameter from multiple model parameters', () => { + const mapper = new LDAIConfigMapper({ + name: 'test-ai-model', + parameters: { + testParam: 123, + }, + }); + // eslint-disable-next-line @typescript-eslint/dot-notation + expect(mapper['_findParameter']('test-param', 'testParam')).toEqual(123); + }); + + it('finds parameter from single model custom parameter', () => { + const mapper = new LDAIConfigMapper({ + name: 'test-ai-model', + custom: { + 'test-param': 123, + }, + }); + // eslint-disable-next-line @typescript-eslint/dot-notation + expect(mapper['_findParameter']('test-param')).toEqual(123); + }); + + it('finds parameter from multiple model custom parameters', () => { + const mapper = new LDAIConfigMapper({ + name: 'test-ai-model', + custom: { + testParam: 123, + }, + }); + // eslint-disable-next-line @typescript-eslint/dot-notation + expect(mapper['_findParameter']('test-param', 'testParam')).toEqual(123); + }); + + it('gives precedence to model parameters over model custom parameters', () => { + const mapper = new LDAIConfigMapper({ + name: 'test-ai-model', + parameters: { + 'test-param': 123, + }, + custom: { + 'test-param': 456, + }, + }); + // eslint-disable-next-line @typescript-eslint/dot-notation +
expect(mapper['_findParameter']('test-param', 'testParam')).toEqual(123); + }); +}); + +describe('toVercelAISDK', () => { + const mockModel = { name: 'mockModel' }; + const mockMessages: LDMessage[] = [ + { role: 'user', content: 'test prompt' }, + { role: 'system', content: 'test instruction' }, + ]; + const mockOptions: VercelAISDKMapOptions = { + nonInterpolatedMessages: [{ role: 'assistant', content: 'test assistant instruction' }], + }; + const mockProvider = jest.fn().mockReturnValue(mockModel); + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('handles undefined model and messages', () => { + const mapper = new LDAIConfigMapper(); + const result = mapper.toVercelAISDK(mockProvider); + + expect(mockProvider).toHaveBeenCalledWith(''); + expect(result).toEqual( + expect.objectContaining({ + model: mockModel, + messages: undefined, + }), + ); + }); + + it('uses additional messages', () => { + const mapper = new LDAIConfigMapper({ name: 'test-ai-model' }); + const result = mapper.toVercelAISDK(mockProvider, mockOptions); + + expect(mockProvider).toHaveBeenCalledWith('test-ai-model'); + expect(result).toEqual( + expect.objectContaining({ + model: mockModel, + messages: mockOptions.nonInterpolatedMessages, + }), + ); + }); + + it('combines config messages and additional messages', () => { + const mapper = new LDAIConfigMapper({ name: 'test-ai-model' }, undefined, mockMessages); + const result = mapper.toVercelAISDK(mockProvider, mockOptions); + + expect(mockProvider).toHaveBeenCalledWith('test-ai-model'); + expect(result).toEqual( + expect.objectContaining({ + model: mockModel, + messages: [...mockMessages, ...(mockOptions.nonInterpolatedMessages ?? [])], + }), + ); + }); + + it('requests parameters correctly', () => { + const mapper = new LDAIConfigMapper({ name: 'test-ai-model' }, undefined, mockMessages); + const findParameterMock = jest.spyOn(mapper as any, '_findParameter'); + const result = mapper.toVercelAISDK(mockProvider); + + expect(mockProvider).toHaveBeenCalledWith('test-ai-model'); + expect(result).toEqual( + expect.objectContaining({ + model: mockModel, + messages: mockMessages, + }), + ); + expect(findParameterMock).toHaveBeenCalledWith('max_tokens', 'maxTokens'); + expect(findParameterMock).toHaveBeenCalledWith('temperature'); + expect(findParameterMock).toHaveBeenCalledWith('top_p', 'topP'); + expect(findParameterMock).toHaveBeenCalledWith('top_k', 'topK'); + expect(findParameterMock).toHaveBeenCalledWith('presence_penalty', 'presencePenalty'); + expect(findParameterMock).toHaveBeenCalledWith('frequency_penalty', 'frequencyPenalty'); + expect(findParameterMock).toHaveBeenCalledWith('stop', 'stop_sequences', 'stopSequences'); + expect(findParameterMock).toHaveBeenCalledWith('seed'); + }); +}); diff --git a/packages/sdk/server-ai/__tests__/LDAIConfigTrackerImpl.test.ts b/packages/sdk/server-ai/__tests__/LDAIConfigTrackerImpl.test.ts index 483d097889..9dd28574cc 100644 --- a/packages/sdk/server-ai/__tests__/LDAIConfigTrackerImpl.test.ts +++ b/packages/sdk/server-ai/__tests__/LDAIConfigTrackerImpl.test.ts @@ -129,6 +129,13 @@ it('tracks success', () => { { configKey, variationKey, version }, 1, ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:generation:success', + testContext, + { configKey, variationKey, version }, + 1, + ); }); it('tracks OpenAI usage', async () => { @@ -167,6 +174,20 @@ it('tracks OpenAI usage', async () => { 1, ); + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:generation:success', + testContext, + { configKey, variationKey,
version }, + 1, + ); + + expect(mockTrack).not.toHaveBeenCalledWith( + '$ld:ai:generation:error', + expect.anything(), + expect.anything(), + expect.anything(), + ); + expect(mockTrack).toHaveBeenCalledWith( '$ld:ai:tokens:total', testContext, @@ -226,6 +247,13 @@ it('tracks error when OpenAI metrics function throws', async () => { { configKey, variationKey, version }, 1, ); + + expect(mockTrack).not.toHaveBeenCalledWith( + expect.stringMatching(/^\$ld:ai:tokens:/), + expect.anything(), + expect.anything(), + expect.anything(), + ); }); it('tracks Bedrock conversation with successful response', () => { @@ -260,6 +288,20 @@ it('tracks Bedrock conversation with successful response', () => { 1, ); + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:generation:success', + testContext, + { configKey, variationKey, version }, + 1, + ); + + expect(mockTrack).not.toHaveBeenCalledWith( + '$ld:ai:generation:error', + expect.anything(), + expect.anything(), + expect.anything(), + ); + expect(mockTrack).toHaveBeenCalledWith( '$ld:ai:duration:total', testContext, @@ -318,6 +360,409 @@ it('tracks Bedrock conversation with error response', () => { { configKey, variationKey, version }, 1, ); + + expect(mockTrack).not.toHaveBeenCalledWith( + expect.stringMatching(/^\$ld:ai:tokens:/), + expect.anything(), + expect.anything(), + expect.anything(), + ); +}); + +describe('Vercel AI SDK generateText', () => { + it('tracks Vercel AI SDK usage', async () => { + const tracker = new LDAIConfigTrackerImpl( + mockLdClient, + configKey, + variationKey, + version, + testContext, + ); + jest.spyOn(global.Date, 'now').mockReturnValueOnce(1000).mockReturnValueOnce(2000); + + const TOTAL_TOKENS = 100; + const PROMPT_TOKENS = 49; + const COMPLETION_TOKENS = 51; + + await tracker.trackVercelAISDKGenerateTextMetrics(async () => ({ + usage: { + totalTokens: TOTAL_TOKENS, + promptTokens: PROMPT_TOKENS, + completionTokens: COMPLETION_TOKENS, + }, + })); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:duration:total', + testContext, + { configKey, variationKey, version }, + 1000, + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:generation', + testContext, + { configKey, variationKey, version }, + 1, + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:generation:success', + testContext, + { configKey, variationKey, version }, + 1, + ); + + expect(mockTrack).not.toHaveBeenCalledWith( + '$ld:ai:generation:error', + expect.anything(), + expect.anything(), + expect.anything(), + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:tokens:total', + testContext, + { configKey, variationKey, version }, + TOTAL_TOKENS, + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:tokens:input', + testContext, + { configKey, variationKey, version }, + PROMPT_TOKENS, + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:tokens:output', + testContext, + { configKey, variationKey, version }, + COMPLETION_TOKENS, + ); + }); + + it('tracks error when Vercel AI SDK metrics function throws', async () => { + const tracker = new LDAIConfigTrackerImpl( + mockLdClient, + configKey, + variationKey, + version, + testContext, + ); + jest.spyOn(global.Date, 'now').mockReturnValueOnce(1000).mockReturnValueOnce(2000); + + const error = new Error('Vercel AI SDK API error'); + await expect( + tracker.trackVercelAISDKGenerateTextMetrics(async () => { + throw error; + }), + ).rejects.toThrow(error); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:duration:total', + testContext, + { configKey, variationKey, version }, + 1000, 
+ ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:generation', + testContext, + { configKey, variationKey, version }, + 1, + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:generation:error', + testContext, + { configKey, variationKey, version }, + 1, + ); + + expect(mockTrack).not.toHaveBeenCalledWith( + expect.stringMatching(/^\$ld:ai:tokens:/), + expect.anything(), + expect.anything(), + expect.anything(), + ); + }); +}); + +describe('Vercel AI SDK streamText', () => { + it('tracks Vercel AI SDK usage', async () => { + const tracker = new LDAIConfigTrackerImpl( + mockLdClient, + configKey, + variationKey, + version, + testContext, + ); + jest.spyOn(global.Date, 'now').mockReturnValueOnce(1000).mockReturnValueOnce(2000); + + const TOTAL_TOKENS = 100; + const PROMPT_TOKENS = 49; + const COMPLETION_TOKENS = 51; + + let resolveDone: ((value: boolean) => void) | undefined; + const donePromise = new Promise((resolve) => { + resolveDone = resolve; + }); + + const finishReason = Promise.resolve('stop'); + jest + .spyOn(finishReason, 'then') + .mockImplementationOnce((fn) => finishReason.then(fn).finally(() => resolveDone?.(true))); + + tracker.trackVercelAISDKStreamTextMetrics(() => ({ + finishReason, + usage: Promise.resolve({ + totalTokens: TOTAL_TOKENS, + promptTokens: PROMPT_TOKENS, + completionTokens: COMPLETION_TOKENS, + }), + })); + + await donePromise; + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:duration:total', + testContext, + { configKey, variationKey, version }, + 1000, + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:generation', + testContext, + { configKey, variationKey, version }, + 1, + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:generation:success', + testContext, + { configKey, variationKey, version }, + 1, + ); + + expect(mockTrack).not.toHaveBeenCalledWith( + '$ld:ai:generation:error', + expect.anything(), + expect.anything(), + expect.anything(), + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:tokens:total', + testContext, + { configKey, variationKey, version }, + TOTAL_TOKENS, + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:tokens:input', + testContext, + { configKey, variationKey, version }, + PROMPT_TOKENS, + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:tokens:output', + testContext, + { configKey, variationKey, version }, + COMPLETION_TOKENS, + ); + }); + + it('tracks error when Vercel AI SDK metrics function throws', async () => { + const tracker = new LDAIConfigTrackerImpl( + mockLdClient, + configKey, + variationKey, + version, + testContext, + ); + jest.spyOn(global.Date, 'now').mockReturnValueOnce(1000).mockReturnValueOnce(2000); + + const error = new Error('Vercel AI SDK API error'); + expect(() => + tracker.trackVercelAISDKStreamTextMetrics(() => { + throw error; + }), + ).toThrow(error); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:duration:total', + testContext, + { configKey, variationKey, version }, + 1000, + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:generation', + testContext, + { configKey, variationKey, version }, + 1, + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:generation:error', + testContext, + { configKey, variationKey, version }, + 1, + ); + + expect(mockTrack).not.toHaveBeenCalledWith( + expect.stringMatching(/^\$ld:ai:tokens:/), + expect.anything(), + expect.anything(), + expect.anything(), + ); + }); + + it('tracks error when Vercel AI SDK finishes because of an error', async () => { + const tracker = new 
LDAIConfigTrackerImpl( + mockLdClient, + configKey, + variationKey, + version, + testContext, + ); + jest.spyOn(global.Date, 'now').mockReturnValueOnce(1000).mockReturnValueOnce(2000); + + tracker.trackVercelAISDKStreamTextMetrics(() => ({ + finishReason: Promise.resolve('error'), + })); + + await new Promise(process.nextTick); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:duration:total', + testContext, + { configKey, variationKey, version }, + 1000, + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:generation', + testContext, + { configKey, variationKey, version }, + 1, + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:generation:error', + testContext, + { configKey, variationKey, version }, + 1, + ); + + expect(mockTrack).not.toHaveBeenCalledWith( + expect.stringMatching(/^\$ld:ai:tokens:/), + expect.anything(), + expect.anything(), + expect.anything(), + ); + }); + + it('tracks error when Vercel AI SDK finishReason promise rejects', async () => { + const tracker = new LDAIConfigTrackerImpl( + mockLdClient, + configKey, + variationKey, + version, + testContext, + ); + jest.spyOn(global.Date, 'now').mockReturnValueOnce(1000).mockReturnValueOnce(2000); + + tracker.trackVercelAISDKStreamTextMetrics(() => ({ + finishReason: Promise.reject(new Error('Vercel AI SDK API error')), + })); + + await new Promise(process.nextTick); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:duration:total', + testContext, + { configKey, variationKey, version }, + 1000, + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:generation', + testContext, + { configKey, variationKey, version }, + 1, + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:generation:error', + testContext, + { configKey, variationKey, version }, + 1, + ); + + expect(mockTrack).not.toHaveBeenCalledWith( + expect.stringMatching(/^\$ld:ai:tokens:/), + expect.anything(), + expect.anything(), + expect.anything(), + ); + }); + + it('squashes error when Vercel AI SDK usage promise rejects', async () => { + const tracker = new LDAIConfigTrackerImpl( + mockLdClient, + configKey, + variationKey, + version, + testContext, + ); + jest.spyOn(global.Date, 'now').mockReturnValueOnce(1000).mockReturnValueOnce(2000); + + tracker.trackVercelAISDKStreamTextMetrics(() => ({ + finishReason: Promise.resolve('stop'), + usage: Promise.reject(new Error('Vercel AI SDK API error')), + })); + + await new Promise(process.nextTick); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:duration:total', + testContext, + { configKey, variationKey, version }, + 1000, + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:generation', + testContext, + { configKey, variationKey, version }, + 1, + ); + + expect(mockTrack).toHaveBeenCalledWith( + '$ld:ai:generation:success', + testContext, + { configKey, variationKey, version }, + 1, + ); + + expect(mockTrack).not.toHaveBeenCalledWith( + '$ld:ai:generation:error', + expect.anything(), + expect.anything(), + expect.anything(), + ); + + expect(mockTrack).not.toHaveBeenCalledWith( + expect.stringMatching(/^\$ld:ai:tokens:/), + expect.anything(), + expect.anything(), + expect.anything(), + ); + }); }); it('tracks tokens', () => { diff --git a/packages/sdk/server-ai/__tests__/TokenUsage.test.ts b/packages/sdk/server-ai/__tests__/TokenUsage.test.ts index 3dbc8bf6b5..4edaa7f237 100644 --- a/packages/sdk/server-ai/__tests__/TokenUsage.test.ts +++ b/packages/sdk/server-ai/__tests__/TokenUsage.test.ts @@ -1,5 +1,8 @@ -import { createBedrockTokenUsage } from 
'../src/api/metrics/BedrockTokenUsage'; -import { createOpenAiUsage } from '../src/api/metrics/OpenAiUsage'; +import { + createBedrockTokenUsage, + createOpenAiUsage, + createVercelAISDKTokenUsage, +} from '../src/api/metrics'; it('createBedrockTokenUsage should create token usage with all values provided', () => { const usage = createBedrockTokenUsage({ @@ -76,3 +79,41 @@ it('createOpenAiUsage should handle explicitly undefined values', () => { output: 0, }); }); + +it('createVercelAISDKTokenUsage should create token usage with all values provided', () => { + const usage = createVercelAISDKTokenUsage({ + totalTokens: 100, + promptTokens: 40, + completionTokens: 60, + }); + + expect(usage).toEqual({ + total: 100, + input: 40, + output: 60, + }); +}); + +it('createVercelAISDKTokenUsage should default to 0 for missing values', () => { + const usage = createVercelAISDKTokenUsage({}); + + expect(usage).toEqual({ + total: 0, + input: 0, + output: 0, + }); +}); + +it('createVercelAISDKTokenUsage should handle explicitly undefined values', () => { + const usage = createVercelAISDKTokenUsage({ + totalTokens: undefined, + promptTokens: 40, + completionTokens: undefined, + }); + + expect(usage).toEqual({ + total: 0, + input: 40, + output: 0, + }); +}); diff --git a/packages/sdk/server-ai/src/LDAIClientImpl.ts b/packages/sdk/server-ai/src/LDAIClientImpl.ts index bca8431cce..de2079053e 100644 --- a/packages/sdk/server-ai/src/LDAIClientImpl.ts +++ b/packages/sdk/server-ai/src/LDAIClientImpl.ts @@ -2,8 +2,18 @@ import * as Mustache from 'mustache'; import { LDContext } from '@launchdarkly/js-server-sdk-common'; -import { LDAIConfig, LDAIDefaults, LDMessage, LDModelConfig, LDProviderConfig } from './api/config'; +import { + LDAIConfig, + LDAIDefaults, + LDMessage, + LDModelConfig, + LDProviderConfig, + VercelAISDKConfig, + VercelAISDKMapOptions, + VercelAISDKProvider, +} from './api/config'; import { LDAIClient } from './api/LDAIClient'; +import { LDAIConfigMapper } from './LDAIConfigMapper'; import { LDAIConfigTrackerImpl } from './LDAIConfigTrackerImpl'; import { LDClientMin } from './LDClientMin'; @@ -52,7 +62,7 @@ export class LDAIClientImpl implements LDAIClient { ); // eslint-disable-next-line no-underscore-dangle const enabled = !!value._ldMeta?.enabled; - const config: LDAIConfig = { + const config: Omit<LDAIConfig, 'toVercelAISDK'> = { tracker, enabled, }; @@ -73,6 +83,14 @@ export class LDAIClientImpl implements LDAIClient { })); } - return config; + const mapper = new LDAIConfigMapper(config.model, config.provider, config.messages); + + return { + ...config, + toVercelAISDK: <TMod>( + provider: VercelAISDKProvider<TMod> | Record<string, VercelAISDKProvider<TMod>>, + options?: VercelAISDKMapOptions | undefined, + ): VercelAISDKConfig<TMod> => mapper.toVercelAISDK(provider, options), + }; } } diff --git a/packages/sdk/server-ai/src/LDAIConfigMapper.ts b/packages/sdk/server-ai/src/LDAIConfigMapper.ts new file mode 100644 index 0000000000..4992dd21f0 --- /dev/null +++ b/packages/sdk/server-ai/src/LDAIConfigMapper.ts @@ -0,0 +1,64 @@ +import { + LDMessage, + LDModelConfig, + LDProviderConfig, + VercelAISDKConfig, + VercelAISDKMapOptions, + VercelAISDKProvider, +} from './api/config'; + +export class LDAIConfigMapper { + constructor( + private _model?: LDModelConfig | undefined, + private _provider?: LDProviderConfig | undefined, + private _messages?: LDMessage[] | undefined, + ) {} + + private _findParameter<T>(...paramNames: string[]): T | undefined { + for (let i = 0; i < paramNames.length; i += 1) { + const paramName = paramNames[i]; + if (this._model?.parameters?.[paramName]
!== undefined) { + return this._model?.parameters?.[paramName] as T; + } + if (this._model?.custom?.[paramName] !== undefined) { + return this._model?.custom?.[paramName] as T; + } + } + return undefined; + } + + toVercelAISDK<TMod>( + provider: VercelAISDKProvider<TMod> | Record<string, VercelAISDKProvider<TMod>>, + options?: VercelAISDKMapOptions | undefined, + ): VercelAISDKConfig<TMod> { + let model: TMod | undefined; + if (typeof provider === 'function') { + model = provider(this._model?.name ?? ''); + } else { + model = provider[this._provider?.name ?? '']?.(this._model?.name ?? ''); + } + if (!model) { + throw new Error( + 'Vercel AI SDK model cannot be determined from the supplied provider parameter.', + ); + } + + let messages: LDMessage[] | undefined; + if (this._messages || options?.nonInterpolatedMessages) { + messages = [...(this._messages ?? []), ...(options?.nonInterpolatedMessages ?? [])]; + } + + return { + model, + messages, + maxTokens: this._findParameter<number>('max_tokens', 'maxTokens'), + temperature: this._findParameter<number>('temperature'), + topP: this._findParameter<number>('top_p', 'topP'), + topK: this._findParameter<number>('top_k', 'topK'), + presencePenalty: this._findParameter<number>('presence_penalty', 'presencePenalty'), + frequencyPenalty: this._findParameter<number>('frequency_penalty', 'frequencyPenalty'), + stopSequences: this._findParameter<string[]>('stop', 'stop_sequences', 'stopSequences'), + seed: this._findParameter<number>('seed'), + }; + } +} diff --git a/packages/sdk/server-ai/src/LDAIConfigTrackerImpl.ts b/packages/sdk/server-ai/src/LDAIConfigTrackerImpl.ts index 0972a5eee5..73c4bdfbd5 100644 --- a/packages/sdk/server-ai/src/LDAIConfigTrackerImpl.ts +++ b/packages/sdk/server-ai/src/LDAIConfigTrackerImpl.ts @@ -2,8 +2,13 @@ import { LDContext } from '@launchdarkly/js-server-sdk-common'; import { LDAIConfigTracker } from './api/config'; import { LDAIMetricSummary } from './api/config/LDAIConfigTracker'; -import { createBedrockTokenUsage, LDFeedbackKind, LDTokenUsage } from './api/metrics'; -import { createOpenAiUsage } from './api/metrics/OpenAiUsage'; +import { + createBedrockTokenUsage, + createOpenAiUsage, + createVercelAISDKTokenUsage, + LDFeedbackKind, + LDTokenUsage, +} from './api/metrics'; import { LDClientMin } from './LDClientMin'; export class LDAIConfigTrackerImpl implements LDAIConfigTracker { @@ -121,6 +126,72 @@ export class LDAIConfigTrackerImpl implements LDAIConfigTracker { return res; } + async trackVercelAISDKGenerateTextMetrics< + TRes extends { + usage?: { + totalTokens?: number; + promptTokens?: number; + completionTokens?: number; + }; + }, + >(func: () => Promise<TRes>): Promise<TRes> { + try { + const result = await this.trackDurationOf(func); + this.trackSuccess(); + if (result.usage) { + this.trackTokens(createVercelAISDKTokenUsage(result.usage)); + } + return result; + } catch (err) { + this.trackError(); + throw err; + } + } + + trackVercelAISDKStreamTextMetrics< + TRes extends { + finishReason?: Promise<string>; + usage?: Promise<{ + totalTokens?: number; + promptTokens?: number; + completionTokens?: number; + }>; + }, + >(func: () => TRes): TRes { + const startTime = Date.now(); + try { + const result = func(); + result.finishReason + ?.then(async (finishReason) => { + const endTime = Date.now(); + this.trackDuration(endTime - startTime); + if (finishReason === 'error') { + this.trackError(); + } else { + this.trackSuccess(); + if (result.usage) { + try { + this.trackTokens(createVercelAISDKTokenUsage(await result.usage)); + } catch { + // Intentionally squashing this error + } + } + } + }) + .catch(() => { + const endTime = Date.now(); +
this.trackDuration(endTime - startTime); + this.trackError(); + }); + return result; + } catch (err) { + const endTime = Date.now(); + this.trackDuration(endTime - startTime); + this.trackError(); + throw err; + } + } + trackTokens(tokens: LDTokenUsage): void { this._trackedMetrics.tokens = tokens; const trackData = this._getTrackData(); diff --git a/packages/sdk/server-ai/src/api/config/LDAIConfig.ts b/packages/sdk/server-ai/src/api/config/LDAIConfig.ts index 308ff02529..c5071a6d59 100644 --- a/packages/sdk/server-ai/src/api/config/LDAIConfig.ts +++ b/packages/sdk/server-ai/src/api/config/LDAIConfig.ts @@ -1,4 +1,5 @@ import { LDAIConfigTracker } from './LDAIConfigTracker'; +import { VercelAISDKConfig, VercelAISDKMapOptions, VercelAISDKProvider } from './VercelAISDK'; /** * Configuration related to the model. @@ -22,7 +23,7 @@ export interface LDModelConfig { export interface LDProviderConfig { /** - * The ID of the provider. + * The name of the provider. */ name: string; } @@ -68,13 +69,29 @@ export interface LDAIConfig { * Whether the configuration is enabled. */ enabled: boolean; + + /** + * Maps this AI config to a format usable directly in Vercel AI SDK generateText() + * and streamText() methods. + * + * WARNING: this method can throw an exception if a Vercel AI SDK model cannot be determined. + * + * @param provider A Vercel AI SDK Provider or a map of provider names to Vercel AI SDK Providers. + * @param options Optional mapping options. + * @returns A configuration directly usable in Vercel AI SDK generateText() and streamText() + * @throws {Error} if a Vercel AI SDK model cannot be determined from the given provider parameter. + */ + toVercelAISDK: <TMod>( + provider: VercelAISDKProvider<TMod> | Record<string, VercelAISDKProvider<TMod>>, + options?: VercelAISDKMapOptions | undefined, + ) => VercelAISDKConfig<TMod>; } /** * Default value for a `modelConfig`. This is the same as the LDAIConfig, but it does not include - * a tracker and `enabled` is optional. + * a tracker or mapper, and `enabled` is optional. */ -export type LDAIDefaults = Omit<LDAIConfig, 'tracker'> & { +export type LDAIDefaults = Omit<LDAIConfig, 'tracker' | 'toVercelAISDK'> & { /** * Whether the configuration is enabled. * diff --git a/packages/sdk/server-ai/src/api/config/LDAIConfigTracker.ts b/packages/sdk/server-ai/src/api/config/LDAIConfigTracker.ts index 2f92aa9386..dfed0fa4db 100644 --- a/packages/sdk/server-ai/src/api/config/LDAIConfigTracker.ts +++ b/packages/sdk/server-ai/src/api/config/LDAIConfigTracker.ts @@ -133,6 +133,55 @@ export interface LDAIConfigTracker { res: TRes, ): TRes; + /** + * Track a Vercel AI SDK generateText operation. + * + * This function will track the duration of the operation, the token usage, and the success or error status. + * + * If the provided function throws, then this method will also throw. + * In that case, it will still record the duration and an error. + * A failed operation will not have any token usage data. + * + * @param func Function which executes the operation. + * @returns The result of the operation. + */ + trackVercelAISDKGenerateTextMetrics< + TRes extends { + usage?: { + totalTokens?: number; + promptTokens?: number; + completionTokens?: number; + }; + }, + >( + func: () => Promise<TRes>, + ): Promise<TRes>; + + /** + * Track a Vercel AI SDK streamText operation. + * + * This function will track the duration of the operation, the token usage, and the success or error status. + * + * If the provided function throws, then this method will also throw. + * In that case, it will still record the duration and an error.
+ * A failed operation will not have any token usage data. + * + * @param func Function which executes the operation. + * @returns The result of the operation. + */ + trackVercelAISDKStreamTextMetrics< + TRes extends { + finishReason?: Promise<string>; + usage?: Promise<{ + totalTokens?: number; + promptTokens?: number; + completionTokens?: number; + }>; + }, + >( + func: () => TRes, + ): TRes; + /** * Get a summary of the tracked metrics. */ diff --git a/packages/sdk/server-ai/src/api/config/VercelAISDK.ts b/packages/sdk/server-ai/src/api/config/VercelAISDK.ts new file mode 100644 index 0000000000..4387fba06d --- /dev/null +++ b/packages/sdk/server-ai/src/api/config/VercelAISDK.ts @@ -0,0 +1,20 @@ +import { type LDMessage } from './LDAIConfig'; + +export type VercelAISDKProvider<TMod> = (modelName: string) => TMod; + +export interface VercelAISDKMapOptions { + nonInterpolatedMessages?: LDMessage[] | undefined; +} + +export interface VercelAISDKConfig<TMod> { + model: TMod; + messages?: LDMessage[] | undefined; + maxTokens?: number | undefined; + temperature?: number | undefined; + topP?: number | undefined; + topK?: number | undefined; + presencePenalty?: number | undefined; + frequencyPenalty?: number | undefined; + stopSequences?: string[] | undefined; + seed?: number | undefined; +} diff --git a/packages/sdk/server-ai/src/api/config/index.ts b/packages/sdk/server-ai/src/api/config/index.ts index 1c07d5c3a4..a3f3752908 100644 --- a/packages/sdk/server-ai/src/api/config/index.ts +++ b/packages/sdk/server-ai/src/api/config/index.ts @@ -1,2 +1,3 @@ export * from './LDAIConfig'; +export * from './VercelAISDK'; export { LDAIConfigTracker } from './LDAIConfigTracker'; diff --git a/packages/sdk/server-ai/src/api/metrics/VercelAISDKTokenUsage.ts b/packages/sdk/server-ai/src/api/metrics/VercelAISDKTokenUsage.ts new file mode 100644 index 0000000000..dbe83a8bf4 --- /dev/null +++ b/packages/sdk/server-ai/src/api/metrics/VercelAISDKTokenUsage.ts @@ -0,0 +1,13 @@ +import { LDTokenUsage } from './LDTokenUsage'; + +export function createVercelAISDKTokenUsage(data: { + totalTokens?: number; + promptTokens?: number; + completionTokens?: number; +}): LDTokenUsage { + return { + total: data.totalTokens ?? 0, + input: data.promptTokens ?? 0, + output: data.completionTokens ?? 0, + }; +} diff --git a/packages/sdk/server-ai/src/api/metrics/index.ts b/packages/sdk/server-ai/src/api/metrics/index.ts index 9f5e199f59..157fbd593c 100644 --- a/packages/sdk/server-ai/src/api/metrics/index.ts +++ b/packages/sdk/server-ai/src/api/metrics/index.ts @@ -1,3 +1,5 @@ export * from './BedrockTokenUsage'; +export * from './OpenAiUsage'; export * from './LDFeedbackKind'; export * from './LDTokenUsage'; +export * from './VercelAISDKTokenUsage';
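
Usage sketch (reviewer note, not part of the diff): the snippet below shows how the new `toVercelAISDK()` and `trackVercelAISDKGenerateTextMetrics()` APIs are intended to compose with the Vercel AI SDK. It is a minimal sketch under stated assumptions: the `ai` and `@ai-sdk/openai` packages, the `init`/`initAi` entry points, the `config()` method name on the AI client, and the SDK key, config key, and context values are illustrative and not defined by this diff.

```ts
import { init } from '@launchdarkly/node-server-sdk';
import { initAi } from '@launchdarkly/server-sdk-ai';
import { openai } from '@ai-sdk/openai';
import { generateText } from 'ai';

const ldClient = init('my-sdk-key'); // hypothetical SDK key
const aiClient = initAi(ldClient);
const context = { kind: 'user', key: 'user-123' }; // hypothetical context

const config = await aiClient.config('my-ai-config', context, { enabled: false });
if (config.enabled) {
  // toVercelAISDK() throws if no model can be determined from the provider,
  // so it is only called when the config is enabled. This also assumes the
  // AI config defines messages; otherwise generateText() needs a prompt.
  const result = await config.tracker.trackVercelAISDKGenerateTextMetrics(() =>
    generateText(config.toVercelAISDK(openai)),
  );
  console.log(result.text);
}
```

Passing the provider function directly makes the mapper call `openai(modelName)` with the config's model name; a `Record` of providers keyed by provider name can be passed instead when the config may target multiple vendors.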
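
A companion sketch for the streaming path, with `config` and the provider set up as in the generateText sketch above. `streamText()` returns synchronously, so `trackVercelAISDKStreamTextMetrics()` does not await the call itself; per the implementation in this diff, it observes the result's `finishReason` and `usage` promises, which is why the tests resolve those promises rather than awaiting the wrapper.

```ts
import { streamText } from 'ai';
import { openai } from '@ai-sdk/openai';

// Metrics are recorded when the stream's finishReason promise settles:
// 'error' or a rejection records an error, anything else records success
// and (if available) token usage.
const stream = config.tracker.trackVercelAISDKStreamTextMetrics(() =>
  streamText(config.toVercelAISDK(openai)),
);

for await (const chunk of stream.textStream) {
  process.stdout.write(chunk);
}
```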