diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts
index 6d628ddfdf..c8e8aadb06 100644
--- a/packages/types/src/provider-settings.ts
+++ b/packages/types/src/provider-settings.ts
@@ -25,6 +25,7 @@ import {
 	vscodeLlmModels,
 	xaiModels,
 	internationalZAiModels,
+	watsonxModels,
 } from "./providers/index.js"
 
 /**
@@ -68,6 +69,7 @@ export const providerNames = [
 	"io-intelligence",
 	"roo",
 	"vercel-ai-gateway",
+	"watsonx",
 ] as const
 
 export const providerNamesSchema = z.enum(providerNames)
@@ -343,6 +345,13 @@ const vercelAiGatewaySchema = baseProviderSettingsSchema.extend({
 	vercelAiGatewayModelId: z.string().optional(),
 })
 
+const watsonxSchema = apiModelIdProviderModelSchema.extend({
+	watsonxApiKey: z.string().optional(),
+	watsonxProjectId: z.string().optional(),
+	watsonxBaseUrl: z.string().optional(),
+	watsonxRegion: z.string().optional(),
+})
+
 const defaultSchema = z.object({
 	apiProvider: z.undefined(),
 })
@@ -384,6 +393,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
 	qwenCodeSchema.merge(z.object({ apiProvider: z.literal("qwen-code") })),
 	rooSchema.merge(z.object({ apiProvider: z.literal("roo") })),
 	vercelAiGatewaySchema.merge(z.object({ apiProvider: z.literal("vercel-ai-gateway") })),
+	watsonxSchema.merge(z.object({ apiProvider: z.literal("watsonx") })),
 	defaultSchema,
 ])
 
@@ -425,6 +435,7 @@ export const providerSettingsSchema = z.object({
 	...qwenCodeSchema.shape,
 	...rooSchema.shape,
 	...vercelAiGatewaySchema.shape,
+	...watsonxSchema.shape,
 	...codebaseIndexProviderSchema.shape,
})
 
@@ -578,6 +589,7 @@ export const MODELS_BY_PROVIDER: Record<
 	unbound: { id: "unbound", label: "Unbound", models: [] },
 	deepinfra: { id: "deepinfra", label: "DeepInfra", models: [] },
 	"vercel-ai-gateway": { id: "vercel-ai-gateway", label: "Vercel AI Gateway", models: [] },
+	watsonx: { id: "watsonx", label: "IBM watsonx", models: Object.keys(watsonxModels) },
 }
 
 export const dynamicProviders = [
@@ -589,6 +601,7 @@ export const dynamicProviders = [
 	"unbound",
 	"deepinfra",
 	"vercel-ai-gateway",
+	"watsonx",
 ] as const satisfies readonly ProviderName[]
 
 export type DynamicProvider = (typeof dynamicProviders)[number]
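Reviewer note: a minimal sketch of how the extended discriminated union accepts a watsonx profile. The schema and export names come from the hunks above; the profile values and the `@roo-code/types` re-export path are assumptions for illustration.

```ts
import { providerSettingsSchemaDiscriminated } from "@roo-code/types"

// Hypothetical profile; all watsonx-specific fields are optional in watsonxSchema.
const result = providerSettingsSchemaDiscriminated.safeParse({
	apiProvider: "watsonx",
	apiModelId: "ibm/granite-3-8b-instruct",
	watsonxApiKey: "<api-key>",
	watsonxProjectId: "<project-id>",
	watsonxRegion: "us-south",
})

if (!result.success) {
	console.error(result.error.issues)
}
```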
diff --git a/packages/types/src/providers/index.ts b/packages/types/src/providers/index.ts
index 21e43aaa99..c9b797757c 100644
--- a/packages/types/src/providers/index.ts
+++ b/packages/types/src/providers/index.ts
@@ -30,3 +30,4 @@ export * from "./xai.js"
 export * from "./vercel-ai-gateway.js"
 export * from "./zai.js"
 export * from "./deepinfra.js"
+export * from "./watsonx.js"
diff --git a/packages/types/src/providers/watsonx.ts b/packages/types/src/providers/watsonx.ts
new file mode 100644
index 0000000000..758e4d4a5a
--- /dev/null
+++ b/packages/types/src/providers/watsonx.ts
@@ -0,0 +1,148 @@
+import type { ModelInfo } from "../model.js"
+
+// IBM watsonx.ai models
+// https://www.ibm.com/products/watsonx-ai
+export type WatsonxModelId = keyof typeof watsonxModels
+
+export const watsonxDefaultModelId: WatsonxModelId = "ibm/granite-3-8b-instruct"
+
+export const watsonxModels = {
+	// Granite models - IBM's foundation models
+	"ibm/granite-3-8b-instruct": {
+		maxTokens: 8192,
+		contextWindow: 8192,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.0002,
+		outputPrice: 0.0006,
+		description: "IBM Granite 3.0 8B Instruct - Optimized for enterprise tasks",
+	},
+	"ibm/granite-3-2b-instruct": {
+		maxTokens: 4096,
+		contextWindow: 4096,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.0001,
+		outputPrice: 0.0003,
+		description: "IBM Granite 3.0 2B Instruct - Lightweight model for simple tasks",
+	},
+	"ibm/granite-20b-multilingual": {
+		maxTokens: 8192,
+		contextWindow: 8192,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.0006,
+		outputPrice: 0.0018,
+		description: "IBM Granite 20B Multilingual - Supports multiple languages",
+	},
+	"ibm/granite-13b-chat-v2": {
+		maxTokens: 8192,
+		contextWindow: 8192,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.0004,
+		outputPrice: 0.0012,
+		description: "IBM Granite 13B Chat v2 - Optimized for conversational AI",
+	},
+	"ibm/granite-13b-instruct-v2": {
+		maxTokens: 8192,
+		contextWindow: 8192,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.0004,
+		outputPrice: 0.0012,
+		description: "IBM Granite 13B Instruct v2 - General purpose instruction following",
+	},
+	"ibm/granite-7b-lab": {
+		maxTokens: 4096,
+		contextWindow: 4096,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.0002,
+		outputPrice: 0.0006,
+		description: "IBM Granite 7B Lab - Experimental model for research",
+	},
+	// Granite Code models - specialized for code generation
+	"ibm/granite-34b-code-instruct": {
+		maxTokens: 8192,
+		contextWindow: 8192,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.001,
+		outputPrice: 0.003,
+		description: "IBM Granite 34B Code Instruct - Specialized for code generation and understanding",
+	},
+	"ibm/granite-20b-code-instruct": {
+		maxTokens: 8192,
+		contextWindow: 8192,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.0006,
+		outputPrice: 0.0018,
+		description: "IBM Granite 20B Code Instruct - Code generation model",
+	},
+	"ibm/granite-8b-code-instruct": {
+		maxTokens: 4096,
+		contextWindow: 4096,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.0002,
+		outputPrice: 0.0006,
+		description: "IBM Granite 8B Code Instruct - Lightweight code model",
+	},
+	"ibm/granite-3b-code-instruct": {
+		maxTokens: 2048,
+		contextWindow: 2048,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.0001,
+		outputPrice: 0.0003,
+		description: "IBM Granite 3B Code Instruct - Fast code completion",
+	},
+	// Third-party models available on watsonx
+	"meta-llama/llama-3-70b-instruct": {
+		maxTokens: 8192,
+		contextWindow: 8192,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.0029,
+		outputPrice: 0.0087,
+		description: "Meta Llama 3 70B Instruct on watsonx",
+	},
+	"meta-llama/llama-3-8b-instruct": {
+		maxTokens: 8192,
+		contextWindow: 8192,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.0002,
+		outputPrice: 0.0006,
+		description: "Meta Llama 3 8B Instruct on watsonx",
+	},
+	"mistralai/mixtral-8x7b-instruct-v01": {
+		maxTokens: 4096,
+		contextWindow: 32768,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.0005,
+		outputPrice: 0.0015,
+		description: "Mistral Mixtral 8x7B Instruct on watsonx",
+	},
+	"mistralai/mistral-large": {
+		maxTokens: 8192,
+		contextWindow: 32768,
+		supportsImages: false,
+		supportsPromptCache: false,
+		inputPrice: 0.003,
+		outputPrice: 0.009,
+		description: "Mistral Large on watsonx",
+	},
+} as const satisfies Record<string, ModelInfo>
+
+export const watsonxModelInfoSaneDefaults: ModelInfo = {
+	maxTokens: 4096,
+	contextWindow: 8192,
+	supportsImages: false,
+	supportsPromptCache: false,
+	inputPrice: 0,
+	outputPrice: 0,
+}
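The `as const satisfies Record<string, ModelInfo>` pattern keeps the literal keys (so `WatsonxModelId` stays a union of the model ids) while still type-checking every entry against `ModelInfo`. A small sketch of the lookup this enables; `resolveWatsonxModelId` is a hypothetical helper mirroring the fallback logic in the handler's `getModel()` below:

```ts
import { watsonxModels, watsonxDefaultModelId, type WatsonxModelId } from "@roo-code/types"

// Narrow an arbitrary configured string to a known model id, falling back to the default.
function resolveWatsonxModelId(raw: string | undefined): WatsonxModelId {
	return raw && raw in watsonxModels ? (raw as WatsonxModelId) : watsonxDefaultModelId
}

const info = watsonxModels[resolveWatsonxModelId("ibm/granite-34b-code-instruct")]
console.log(info.contextWindow) // 8192
```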
diff --git a/src/api/index.ts b/src/api/index.ts
index ac00967676..35f9b81771 100644
--- a/src/api/index.ts
+++ b/src/api/index.ts
@@ -40,6 +40,7 @@ import {
 	FeatherlessHandler,
 	VercelAiGatewayHandler,
 	DeepInfraHandler,
+	WatsonxHandler,
 } from "./providers"
 
 import { NativeOllamaHandler } from "./providers/native-ollama"
@@ -165,6 +166,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 			return new FeatherlessHandler(options)
 		case "vercel-ai-gateway":
 			return new VercelAiGatewayHandler(options)
+		case "watsonx":
+			return new WatsonxHandler(options)
 		default:
 			apiProvider satisfies "gemini-cli" | undefined
 			return new AnthropicHandler(options)
diff --git a/src/api/providers/index.ts b/src/api/providers/index.ts
index 85d877b6bc..5509591395 100644
--- a/src/api/providers/index.ts
+++ b/src/api/providers/index.ts
@@ -34,3 +34,4 @@ export { RooHandler } from "./roo"
 export { FeatherlessHandler } from "./featherless"
 export { VercelAiGatewayHandler } from "./vercel-ai-gateway"
 export { DeepInfraHandler } from "./deepinfra"
+export { WatsonxHandler } from "./watsonx"
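With the dispatch arm and barrel export in place, callers obtain the handler through `buildApiHandler` like any other provider. A usage sketch (the import path and env-var names are illustrative, not part of this PR):

```ts
import { buildApiHandler } from "../../api" // path relative to a caller inside src/

const handler = buildApiHandler({
	apiProvider: "watsonx",
	apiModelId: "ibm/granite-3-8b-instruct",
	watsonxApiKey: process.env.WATSONX_API_KEY,
	watsonxProjectId: process.env.WATSONX_PROJECT_ID,
	watsonxRegion: "us-south",
})
```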
diff --git a/src/api/providers/watsonx.ts b/src/api/providers/watsonx.ts
new file mode 100644
index 0000000000..fba146c326
--- /dev/null
+++ b/src/api/providers/watsonx.ts
@@ -0,0 +1,168 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
+
+import { type ModelInfo, type WatsonxModelId, watsonxDefaultModelId, watsonxModels } from "@roo-code/types"
+
+import type { ApiHandlerOptions } from "../../shared/api"
+
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { getModelParams } from "../transform/model-params"
+
+import { DEFAULT_HEADERS } from "./constants"
+import { BaseProvider } from "./base-provider"
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+import { getApiRequestTimeout } from "./utils/timeout-config"
+import { handleOpenAIError } from "./utils/openai-error-handler"
+
+export class WatsonxHandler extends BaseProvider implements SingleCompletionHandler {
+	protected options: ApiHandlerOptions
+	private client: OpenAI
+	private readonly providerName = "IBM watsonx"
+
+	constructor(options: ApiHandlerOptions) {
+		super()
+		this.options = options
+
+		// Construct the base URL for the watsonx API,
+		// defaulting to the US South region if none is specified.
+		const region = this.options.watsonxRegion || "us-south"
+		const baseURL = this.options.watsonxBaseUrl || `https://${region}.ml.cloud.ibm.com/ml/v1`
+		const apiKey = this.options.watsonxApiKey || "not-provided"
+
+		const headers = {
+			...DEFAULT_HEADERS,
+			"X-Watson-Project-Id": this.options.watsonxProjectId || "",
+		}
+
+		const timeout = getApiRequestTimeout()
+
+		this.client = new OpenAI({
+			baseURL,
+			apiKey,
+			defaultHeaders: headers,
+			timeout,
+		})
+	}
+
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		const { id: modelId, info: modelInfo } = this.getModel()
+
+		// Prepend the system prompt to the OpenAI-format message list
+		const systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = {
+			role: "system",
+			content: systemPrompt,
+		}
+
+		const convertedMessages = [systemMessage, ...convertToOpenAiMessages(messages)]
+
+		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
+			model: modelId,
+			temperature: this.options.modelTemperature ?? 0.7,
+			messages: convertedMessages,
+			stream: true as const,
+			stream_options: { include_usage: true },
+		}
+
+		// Add max_tokens only when explicitly enabled
+		if (this.options.includeMaxTokens === true) {
+			requestOptions.max_tokens = this.options.modelMaxTokens || modelInfo.maxTokens
+		}
+
+		let stream
+		try {
+			stream = await this.client.chat.completions.create(requestOptions)
+		} catch (error) {
+			throw handleOpenAIError(error, this.providerName)
+		}
+
+		let lastUsage
+
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta ?? {}
+
+			if (delta.content) {
+				yield {
+					type: "text",
+					text: delta.content,
+				}
+			}
+
+			if (chunk.usage) {
+				lastUsage = chunk.usage
+			}
+		}
+
+		if (lastUsage) {
+			yield this.processUsageMetrics(lastUsage, modelInfo)
+		}
+	}
+
+	protected processUsageMetrics(usage: any, modelInfo?: ModelInfo): ApiStreamUsageChunk {
+		return {
+			type: "usage",
+			inputTokens: usage?.prompt_tokens || 0,
+			outputTokens: usage?.completion_tokens || 0,
+		}
+	}
+
+	override getModel() {
+		const modelId = this.options.apiModelId
+		let id = modelId && modelId in watsonxModels ? (modelId as WatsonxModelId) : watsonxDefaultModelId
+		let info: ModelInfo = watsonxModels[id]
+		const params = getModelParams({ format: "openai", modelId: id, model: info, settings: this.options })
+		return { id, info, ...params }
+	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		try {
+			const model = this.getModel()
+			const modelInfo = model.info
+
+			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+				model: model.id,
+				messages: [{ role: "user", content: prompt }],
+			}
+
+			// Add max_tokens only when explicitly enabled
+			if (this.options.includeMaxTokens === true) {
+				requestOptions.max_tokens = this.options.modelMaxTokens || modelInfo.maxTokens
+			}
+
+			let response
+			try {
+				response = await this.client.chat.completions.create(requestOptions)
+			} catch (error) {
+				throw handleOpenAIError(error, this.providerName)
+			}
+
+			return response.choices[0]?.message.content || ""
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`${this.providerName} completion error: ${error.message}`)
+			}
+
+			throw error
+		}
+	}
+}
+
+/**
+ * Helper function to get available watsonx models.
+ *
+ * Currently returns the static list of models defined in watsonxModels.
+ * IBM watsonx doesn't provide a public API endpoint for dynamically listing
+ * available models, so we maintain a curated list of models known to work
+ * with the watsonx platform.
+ *
+ * @returns Array of available model IDs
+ */
+export async function getWatsonxModels(): Promise<string[]> {
+	// Return the static list of supported watsonx models. The list is maintained
+	// from IBM's documentation and includes both IBM Granite models and
+	// third-party models available on watsonx.
+	return Object.keys(watsonxModels)
+}
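A minimal consumption sketch for the streaming path, assuming only the `text` and `usage` chunk shapes yielded above (the credential placeholders are illustrative):

```ts
const watsonx = new WatsonxHandler({
	watsonxApiKey: "<api-key>",
	watsonxProjectId: "<project-id>",
	apiModelId: "ibm/granite-3-8b-instruct",
})

for await (const chunk of watsonx.createMessage("You are concise.", [{ role: "user", content: "Say hello." }])) {
	if (chunk.type === "text") process.stdout.write(chunk.text)
	if (chunk.type === "usage") console.log(`\n${chunk.inputTokens} in / ${chunk.outputTokens} out`)
}
```

Design note: the handler drives watsonx through the OpenAI SDK pointed at the regional `ml.cloud.ibm.com` base URL rather than through an IBM-specific client, which keeps error handling (`handleOpenAIError`) and timeouts (`getApiRequestTimeout`) consistent with the other OpenAI-compatible providers in this directory.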
diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts
index f8a005e86a..d53722ffd0 100644
--- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts
+++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts
@@ -57,6 +57,8 @@ import {
 	vercelAiGatewayDefaultModelId,
 	BEDROCK_CLAUDE_SONNET_4_MODEL_ID,
 	deepInfraDefaultModelId,
+	watsonxDefaultModelId,
+	watsonxModels,
 } from "@roo-code/types"
 
 import type { ModelRecord, RouterModels } from "@roo/api"
@@ -348,11 +350,16 @@ function getSelectedModel({
 			const info = routerModels["vercel-ai-gateway"]?.[id]
 			return { id, info }
 		}
+		case "watsonx": {
+			const id = apiConfiguration.apiModelId ?? watsonxDefaultModelId
+			const info = watsonxModels[id as keyof typeof watsonxModels]
+			return { id, info }
+		}
 		// case "anthropic":
 		// case "human-relay":
 		// case "fake-ai":
 		default: {
-			provider satisfies "anthropic" | "gemini-cli" | "qwen-code" | "human-relay" | "fake-ai"
+			provider satisfies "anthropic" | "gemini-cli" | "human-relay" | "fake-ai"
 			const id = apiConfiguration.apiModelId ?? anthropicDefaultModelId
 			const baseInfo = anthropicModels[id as keyof typeof anthropicModels]
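Unlike the router-backed providers above it, the new `watsonx` case resolves model info from the static `watsonxModels` table, so an `apiModelId` outside that table yields `info === undefined`, which callers of `getSelectedModel` need to tolerate. A worked example under a hypothetical configuration:

```ts
const apiConfiguration = { apiProvider: "watsonx", apiModelId: "mistralai/mistral-large" } as const

const id = apiConfiguration.apiModelId ?? watsonxDefaultModelId
const info = watsonxModels[id as keyof typeof watsonxModels]
// info.maxTokens === 8192, info.contextWindow === 32768
```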