diff --git a/packages/sdk/server-ai/src/LDAIClientImpl.ts b/packages/sdk/server-ai/src/LDAIClientImpl.ts index 0cb7f96c0d..c116e460fb 100644 --- a/packages/sdk/server-ai/src/LDAIClientImpl.ts +++ b/packages/sdk/server-ai/src/LDAIClientImpl.ts @@ -3,7 +3,7 @@ import * as Mustache from 'mustache'; import { LDContext, LDLogger } from '@launchdarkly/js-server-sdk-common'; import { LDAIAgent, LDAIAgentConfig, LDAIAgentDefaults } from './api/agents'; -import { TrackedChat, TrackedChatFactory } from './api/chat'; +import { SupportedAIProvider, TrackedChat, TrackedChatFactory } from './api/chat'; import { LDAIConfig, LDAIConfigTracker, @@ -233,6 +233,7 @@ export class LDAIClientImpl implements LDAIClient { context: LDContext, defaultValue: LDAIDefaults, variables?: Record, + defaultAiProvider?: SupportedAIProvider, ): Promise { // Track chat initialization this._ldClient.track('$ld:ai:config:function:initChat', context, key, 1); @@ -246,6 +247,6 @@ export class LDAIClientImpl implements LDAIClient { } // Create the TrackedChat instance based on the provider - return TrackedChatFactory.create(aiConfig, aiConfig.tracker, this._logger); + return TrackedChatFactory.create(aiConfig, aiConfig.tracker, this._logger, defaultAiProvider); } } diff --git a/packages/sdk/server-ai/src/api/LDAIClient.ts b/packages/sdk/server-ai/src/api/LDAIClient.ts index ccf3098938..d11801fcb1 100644 --- a/packages/sdk/server-ai/src/api/LDAIClient.ts +++ b/packages/sdk/server-ai/src/api/LDAIClient.ts @@ -1,7 +1,7 @@ import { LDContext } from '@launchdarkly/js-server-sdk-common'; import { LDAIAgent, LDAIAgentConfig, LDAIAgentDefaults } from './agents'; -import { TrackedChat } from './chat'; +import { SupportedAIProvider, TrackedChat } from './chat'; import { LDAIConfig, LDAIDefaults } from './config/LDAIConfig'; /** @@ -186,5 +186,6 @@ export interface LDAIClient { context: LDContext, defaultValue: LDAIDefaults, variables?: Record, + defaultAiProvider?: SupportedAIProvider, ): Promise; } diff --git a/packages/sdk/server-ai/src/api/chat/TrackedChatFactory.ts b/packages/sdk/server-ai/src/api/chat/TrackedChatFactory.ts index ea47625f78..4017a96417 100644 --- a/packages/sdk/server-ai/src/api/chat/TrackedChatFactory.ts +++ b/packages/sdk/server-ai/src/api/chat/TrackedChatFactory.ts @@ -5,6 +5,21 @@ import { LDAIConfigTracker } from '../config/LDAIConfigTracker'; import { AIProvider } from '../providers/AIProvider'; import { TrackedChat } from './TrackedChat'; +/** + * List of supported AI providers. + */ +export const SUPPORTED_AI_PROVIDERS = [ + 'openai', + // Multi-provider packages should be last in the list + 'langchain', + 'vercel', +] as const; + +/** + * Type representing the supported AI providers. + */ +export type SupportedAIProvider = (typeof SUPPORTED_AI_PROVIDERS)[number]; + /** * Factory for creating TrackedChat instances based on the provider configuration. */ @@ -17,13 +32,15 @@ export class TrackedChatFactory { * @param aiConfig The AI configuration * @param tracker The tracker for AI operations * @param logger Optional logger for logging provider initialization + * @param defaultAiProvider Optional default AI provider to use */ static async create( aiConfig: LDAIConfig, tracker: LDAIConfigTracker, logger?: LDLogger, + defaultAiProvider?: SupportedAIProvider, ): Promise { - const provider = await this._createAIProvider(aiConfig, logger); + const provider = await this._createAIProvider(aiConfig, logger, defaultAiProvider); if (!provider) { logger?.warn( `Provider is not supported or failed to initialize: ${aiConfig.provider?.name ?? 'unknown'}`, @@ -31,7 +48,6 @@ export class TrackedChatFactory { return undefined; } - logger?.debug(`Successfully created TrackedChat for provider: ${aiConfig.provider?.name}`); return new TrackedChat(aiConfig, tracker, provider); } @@ -42,53 +58,114 @@ export class TrackedChatFactory { private static async _createAIProvider( aiConfig: LDAIConfig, logger?: LDLogger, + defaultAiProvider?: SupportedAIProvider, ): Promise { const providerName = aiConfig.provider?.name?.toLowerCase(); - logger?.debug(`Attempting to create AI provider: ${providerName ?? 'unknown'}`); - let provider: AIProvider | undefined; + // Determine which providers to try based on defaultAiProvider + const providersToTry = this._getProvidersToTry(defaultAiProvider, providerName); - // Try specific implementations for the provider - switch (providerName) { - case 'openai': - // TODO: Return OpenAI AIProvider implementation when available - provider = undefined; - break; - case 'bedrock': - // TODO: Return Bedrock AIProvider implementation when available - provider = undefined; - break; - default: - provider = undefined; + // Try each provider in order + // eslint-disable-next-line no-restricted-syntax + for (const providerType of providersToTry) { + // eslint-disable-next-line no-await-in-loop + const provider = await this._tryCreateProvider(providerType, aiConfig, logger); + if (provider) { + return provider; + } } - // If no specific implementation worked, try the multi-provider packages - if (!provider) { - provider = await this._createLangChainProvider(aiConfig, logger); + return undefined; + } + + /** + * Determine which providers to try based on defaultAiProvider and providerName. + */ + private static _getProvidersToTry( + defaultAiProvider?: SupportedAIProvider, + providerName?: string, + ): SupportedAIProvider[] { + // If defaultAiProvider is set, only try that specific provider + if (defaultAiProvider) { + return [defaultAiProvider]; + } + + // If no defaultAiProvider is set, try all providers in order + const providerSet = new Set(); + + // First try the specific provider if it's supported + if (providerName && SUPPORTED_AI_PROVIDERS.includes(providerName as SupportedAIProvider)) { + providerSet.add(providerName as SupportedAIProvider); } - return provider; + // Then try multi-provider packages, but avoid duplicates + const multiProviderPackages: SupportedAIProvider[] = ['langchain', 'vercel']; + multiProviderPackages.forEach((provider) => { + providerSet.add(provider); + }); + + return Array.from(providerSet); } /** - * Create a LangChain AIProvider instance if the LangChain provider is available. + * Try to create a provider of the specified type. */ - private static async _createLangChainProvider( + private static async _tryCreateProvider( + providerType: SupportedAIProvider, + aiConfig: LDAIConfig, + logger?: LDLogger, + ): Promise { + switch (providerType) { + case 'openai': + return this._createProvider( + '@launchdarkly/server-sdk-ai-openai', + 'OpenAIProvider', + aiConfig, + logger, + ); + case 'langchain': + return this._createProvider( + '@launchdarkly/server-sdk-ai-langchain', + 'LangChainProvider', + aiConfig, + logger, + ); + case 'vercel': + return this._createProvider( + '@launchdarkly/server-sdk-ai-vercel', + 'VercelProvider', + aiConfig, + logger, + ); + default: + return undefined; + } + } + + /** + * Create a provider instance dynamically. + */ + private static async _createProvider( + packageName: string, + providerClassName: string, aiConfig: LDAIConfig, logger?: LDLogger, ): Promise { try { - logger?.debug('Attempting to load LangChain provider'); - // Try to dynamically import the LangChain provider - // This will work if @launchdarkly/server-sdk-ai-langchain is installed - // eslint-disable-next-line import/no-extraneous-dependencies, global-require - const { LangChainProvider } = require('@launchdarkly/server-sdk-ai-langchain'); - - const provider = await LangChainProvider.create(aiConfig, logger); - logger?.debug('Successfully created LangChain provider'); + // Try to dynamically import the provider + // This will work if the package is installed + // eslint-disable-next-line import/no-extraneous-dependencies, global-require, import/no-dynamic-require + const { [providerClassName]: ProviderClass } = require(packageName); + + const provider = await ProviderClass.create(aiConfig, logger); + logger?.debug( + `Successfully created AIProvider for: ${aiConfig.provider?.name} with package ${packageName}`, + ); return provider; } catch (error) { - // If the LangChain provider is not available or creation fails, return undefined - logger?.error(`Error creating LangChain provider: ${error}`); + // If the provider is not available or creation fails, return undefined + logger?.warn( + `Error creating AIProvider for: ${aiConfig.provider?.name} with package ${packageName}: ${error}`, + ); return undefined; } }