From 51f713f0e331a1cc4c6d87047679b452d48cf251 Mon Sep 17 00:00:00 2001
From: Thach Nguyen
Date: Fri, 29 Aug 2025 11:55:19 -0700
Subject: [PATCH 1/3] feat(provider): add DeepInfra with dynamic model fetching & prompt-caching

---
 packages/types/src/global-settings.ts        |   1 +
 packages/types/src/provider-settings.ts      |  12 ++
 packages/types/src/providers/deepinfra.ts    |  14 ++
 packages/types/src/providers/index.ts        |   1 +
 src/api/index.ts                             |   3 +
 src/api/providers/deepinfra.ts               | 147 ++++++++++++++++++
 src/api/providers/fetchers/deepinfra.ts      |  71 +++++++++
 src/api/providers/fetchers/modelCache.ts     |   4 +
 src/api/providers/index.ts                   |   1 +
 src/core/webview/webviewMessageHandler.ts    |   9 ++
 src/shared/ProfileValidator.ts               |   2 +
 src/shared/api.ts                            |   2 +
 .../src/components/settings/ApiOptions.tsx   |  18 +++
 .../src/components/settings/ModelPicker.tsx  |   1 +
 .../src/components/settings/constants.ts     |   1 +
 .../settings/providers/DeepInfra.tsx         |  94 +++++++++++
 .../components/settings/providers/index.ts   |   1 +
 .../components/ui/hooks/useSelectedModel.ts  |   6 +
 .../src/utils/__tests__/validate.test.ts     |   1 +
 webview-ui/src/utils/validate.ts             |  10 ++
 20 files changed, 399 insertions(+)
 create mode 100644 packages/types/src/providers/deepinfra.ts
 create mode 100644 src/api/providers/deepinfra.ts
 create mode 100644 src/api/providers/fetchers/deepinfra.ts
 create mode 100644 webview-ui/src/components/settings/providers/DeepInfra.tsx

diff --git a/packages/types/src/global-settings.ts b/packages/types/src/global-settings.ts
index 81c6ae6dfe..f1c4b81c48 100644
--- a/packages/types/src/global-settings.ts
+++ b/packages/types/src/global-settings.ts
@@ -192,6 +192,7 @@ export const SECRET_STATE_KEYS = [
 	"groqApiKey",
 	"chutesApiKey",
 	"litellmApiKey",
+	"deepInfraApiKey",
 	"codeIndexOpenAiKey",
 	"codeIndexQdrantApiKey",
 	"codebaseIndexOpenAiCompatibleApiKey",
diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts
index 090dfe6693..d1831163f8 100644
--- a/packages/types/src/provider-settings.ts
+++ b/packages/types/src/provider-settings.ts
@@ -48,6 +48,7 @@ export const providerNames = [
 	"mistral",
 	"moonshot",
 	"deepseek",
+	"deepinfra",
 	"doubao",
 	"qwen-code",
 	"unbound",
@@ -236,6 +237,12 @@ const deepSeekSchema = apiModelIdProviderModelSchema.extend({
 	deepSeekApiKey: z.string().optional(),
 })
 
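+// DeepInfra (OpenAI-compatible): optional base URL override, API key, and model id.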
+const deepInfraSchema = apiModelIdProviderModelSchema.extend({
+	deepInfraBaseUrl: z.string().optional(),
+	deepInfraApiKey: z.string().optional(),
+	deepInfraModelId: z.string().optional(),
+})
+
 const doubaoSchema = apiModelIdProviderModelSchema.extend({
 	doubaoBaseUrl: z.string().optional(),
 	doubaoApiKey: z.string().optional(),
@@ -349,6 +356,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
 	openAiNativeSchema.merge(z.object({ apiProvider: z.literal("openai-native") })),
 	mistralSchema.merge(z.object({ apiProvider: z.literal("mistral") })),
 	deepSeekSchema.merge(z.object({ apiProvider: z.literal("deepseek") })),
+	deepInfraSchema.merge(z.object({ apiProvider: z.literal("deepinfra") })),
 	doubaoSchema.merge(z.object({ apiProvider: z.literal("doubao") })),
 	moonshotSchema.merge(z.object({ apiProvider: z.literal("moonshot") })),
 	unboundSchema.merge(z.object({ apiProvider: z.literal("unbound") })),
@@ -389,6 +397,7 @@ export const providerSettingsSchema = z.object({
 	...openAiNativeSchema.shape,
 	...mistralSchema.shape,
 	...deepSeekSchema.shape,
+	...deepInfraSchema.shape,
 	...doubaoSchema.shape,
 	...moonshotSchema.shape,
 	...unboundSchema.shape,
@@ -438,6 +447,7 @@ export const MODEL_ID_KEYS: Partial<keyof ProviderSettings>[] = [
 	"huggingFaceModelId",
 	"ioIntelligenceModelId",
 	"vercelAiGatewayModelId",
+	"deepInfraModelId",
 ]
 
 export const getModelId = (settings: ProviderSettings): string | undefined => {
@@ -559,6 +569,7 @@ export const MODELS_BY_PROVIDER: Record<
 	openrouter: { id: "openrouter", label: "OpenRouter", models: [] },
 	requesty: { id: "requesty", label: "Requesty", models: [] },
 	unbound: { id: "unbound", label: "Unbound", models: [] },
+	deepinfra: { id: "deepinfra", label: "DeepInfra", models: [] },
 	"vercel-ai-gateway": { id: "vercel-ai-gateway", label: "Vercel AI Gateway", models: [] },
 }
 
@@ -569,6 +580,7 @@ export const dynamicProviders = [
 	"openrouter",
 	"requesty",
 	"unbound",
+	"deepinfra",
 	"vercel-ai-gateway",
 ] as const satisfies readonly ProviderName[]
 
diff --git a/packages/types/src/providers/deepinfra.ts b/packages/types/src/providers/deepinfra.ts
new file mode 100644
index 0000000000..9a430b3789
--- /dev/null
+++ b/packages/types/src/providers/deepinfra.ts
@@ -0,0 +1,14 @@
+import type { ModelInfo } from "../model.js"
+
+// Default fallback values for DeepInfra when model metadata is not yet loaded.
+export const deepInfraDefaultModelId = "Qwen/Qwen3-Coder-480B-A35B-Instruct-Turbo"
+
+export const deepInfraDefaultModelInfo: ModelInfo = {
+	maxTokens: 16384,
+	contextWindow: 262144,
+	supportsImages: false,
+	supportsPromptCache: false,
+	inputPrice: 0.3,
+	outputPrice: 1.2,
+	description: "Qwen 3 Coder 480B A35B Instruct Turbo model, 256K context.",
+}
diff --git a/packages/types/src/providers/index.ts b/packages/types/src/providers/index.ts
index 97fa10ca82..21e43aaa99 100644
--- a/packages/types/src/providers/index.ts
+++ b/packages/types/src/providers/index.ts
@@ -29,3 +29,4 @@ export * from "./vscode-llm.js"
 export * from "./xai.js"
 export * from "./vercel-ai-gateway.js"
 export * from "./zai.js"
+export * from "./deepinfra.js"
diff --git a/src/api/index.ts b/src/api/index.ts
index b50afbb023..ac00967676 100644
--- a/src/api/index.ts
+++ b/src/api/index.ts
@@ -39,6 +39,7 @@ import {
 	RooHandler,
 	FeatherlessHandler,
 	VercelAiGatewayHandler,
+	DeepInfraHandler,
 } from "./providers"
 
 import { NativeOllamaHandler } from "./providers/native-ollama"
@@ -138,6 +139,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 			return new XAIHandler(options)
 		case "groq":
 			return new GroqHandler(options)
+		case "deepinfra":
+			return new DeepInfraHandler(options)
 		case "huggingface":
 			return new HuggingFaceHandler(options)
 		case "chutes":
diff --git a/src/api/providers/deepinfra.ts b/src/api/providers/deepinfra.ts
new file mode 100644
index 0000000000..7cf018b069
--- /dev/null
+++ b/src/api/providers/deepinfra.ts
@@ -0,0 +1,147 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
+
+import { deepInfraDefaultModelId, deepInfraDefaultModelInfo } from "@roo-code/types"
+
+import type { ApiHandlerOptions } from "../../shared/api"
+import { calculateApiCostOpenAI } from "../../shared/cost"
+
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+import { RouterProvider } from "./router-provider"
+import { getModelParams } from "../transform/model-params"
+import { getModels } from "./fetchers/modelCache"
+
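+// DeepInfra exposes an OpenAI-compatible API. RouterProvider supplies the cached,
+// dynamically fetched model list; this handler adds streaming, reasoning deltas,
+// and prompt-cache usage reporting on top of it.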
+export class DeepInfraHandler extends RouterProvider implements SingleCompletionHandler {
+	constructor(options: ApiHandlerOptions) {
+		super({
+			options: {
+				...options,
+				openAiHeaders: {
+					"X-Deepinfra-Source": "roo-code",
+					"X-Deepinfra-Version": "2025-08-25",
+				},
+			},
+			name: "deepinfra",
+			baseURL: `${options.deepInfraBaseUrl || "https://api.deepinfra.com/v1/openai"}`,
+			apiKey: options.deepInfraApiKey || "not-provided",
+			modelId: options.deepInfraModelId,
+			defaultModelId: deepInfraDefaultModelId,
+			defaultModelInfo: deepInfraDefaultModelInfo,
+		})
+	}
+
+	public override async fetchModel() {
+		this.models = await getModels({ provider: this.name, apiKey: this.client.apiKey, baseUrl: this.client.baseURL })
+		return this.getModel()
+	}
+
+	override getModel() {
+		const id = this.options.deepInfraModelId ?? deepInfraDefaultModelId
+		const info = this.models[id] ?? deepInfraDefaultModelInfo
+
+		const params = getModelParams({
+			format: "openai",
+			modelId: id,
+			model: info,
+			settings: this.options,
+		})
+
+		return { id, info, ...params }
+	}
+
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		// Ensure we have up-to-date model metadata, then resolve id/info/params once.
+		const { id: modelId, info, reasoningEffort: reasoning_effort } = await this.fetchModel()
+
+		// Prompt caching is keyed per conversation; reuse the task id when available.
+		let prompt_cache_key: string | undefined = undefined
+		if (info.supportsPromptCache && metadata?.taskId) {
+			prompt_cache_key = metadata.taskId
+		}
+
+		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
+			model: modelId,
+			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+			stream: true,
+			stream_options: { include_usage: true },
+			reasoning_effort,
+			prompt_cache_key,
+		} as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming
+
+		if (this.supportsTemperature(modelId)) {
+			requestOptions.temperature = this.options.modelTemperature ?? 0
+		}
+
+		if (this.options.includeMaxTokens === true && info.maxTokens) {
+			;(requestOptions as any).max_completion_tokens = this.options.modelMaxTokens || info.maxTokens
+		}
+
+		const { data: stream } = await this.client.chat.completions.create(requestOptions).withResponse()
+
+		let lastUsage: OpenAI.CompletionUsage | undefined
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+
+			if (delta?.content) {
+				yield { type: "text", text: delta.content }
+			}
+
+			if (delta && "reasoning_content" in delta && delta.reasoning_content) {
+				yield { type: "reasoning", text: (delta.reasoning_content as string | undefined) || "" }
+			}
+
+			if (chunk.usage) {
+				lastUsage = chunk.usage
+			}
+		}
+
+		if (lastUsage) {
+			yield this.processUsageMetrics(lastUsage, info)
+		}
+	}
+
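+	// Non-streaming, single-shot completion path required by SingleCompletionHandler.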
+	async completePrompt(prompt: string): Promise<string> {
+		await this.fetchModel()
+		const { id: modelId, info } = this.getModel()
+
+		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+			model: modelId,
+			messages: [{ role: "user", content: prompt }],
+		}
+		if (this.supportsTemperature(modelId)) {
+			requestOptions.temperature = this.options.modelTemperature ?? 0
+		}
+		if (this.options.includeMaxTokens === true && info.maxTokens) {
+			;(requestOptions as any).max_completion_tokens = this.options.modelMaxTokens || info.maxTokens
+		}
+
+		const resp = await this.client.chat.completions.create(requestOptions)
+		return resp.choices[0]?.message?.content || ""
+	}
+
+	protected processUsageMetrics(usage: any, modelInfo?: any): ApiStreamUsageChunk {
+		const inputTokens = usage?.prompt_tokens || 0
+		const outputTokens = usage?.completion_tokens || 0
+		const cacheWriteTokens = usage?.prompt_tokens_details?.cache_write_tokens || 0
+		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
+
+		const totalCost = modelInfo
+			? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
+			: 0
+
+		return {
+			type: "usage",
+			inputTokens,
+			outputTokens,
+			cacheWriteTokens: cacheWriteTokens || undefined,
+			cacheReadTokens: cacheReadTokens || undefined,
+			totalCost,
+		}
+	}
+}
diff --git a/src/api/providers/fetchers/deepinfra.ts b/src/api/providers/fetchers/deepinfra.ts
new file mode 100644
index 0000000000..f38daff822
--- /dev/null
+++ b/src/api/providers/fetchers/deepinfra.ts
@@ -0,0 +1,71 @@
+import axios from "axios"
+import { z } from "zod"
+
+import { type ModelInfo } from "@roo-code/types"
+
+import { DEFAULT_HEADERS } from "../constants"
+
+// DeepInfra's models endpoint follows the OpenAI /models shape with an added metadata object.
+const DeepInfraModelSchema = z.object({
+	id: z.string(),
+	object: z.literal("model").optional(),
+	owned_by: z.string().optional(),
+	created: z.number().optional(),
+	root: z.string().optional(),
+	metadata: z
+		.object({
+			description: z.string().optional(),
+			context_length: z.number().optional(),
+			max_tokens: z.number().optional(),
+			tags: z.array(z.string()).optional(), // e.g., ["vision", "prompt_cache"]
+			pricing: z
+				.object({
+					input_tokens: z.number().optional(),
+					output_tokens: z.number().optional(),
+					cache_read_tokens: z.number().optional(),
+				})
+				.optional(),
+		})
+		.optional(),
+})
+
+const DeepInfraModelsResponseSchema = z.object({ data: z.array(DeepInfraModelSchema) })
+
+export async function getDeepInfraModels(
+	apiKey?: string,
+	baseUrl: string = "https://api.deepinfra.com/v1/openai",
+): Promise<Record<string, ModelInfo>> {
+	const headers: Record<string, string> = { ...DEFAULT_HEADERS }
+	if (apiKey) headers["Authorization"] = `Bearer ${apiKey}`
+
+	const url = `${baseUrl.replace(/\/$/, "")}/models`
+	const models: Record<string, ModelInfo> = {}
+
+	const response = await axios.get(url, { headers })
+	const parsed = DeepInfraModelsResponseSchema.safeParse(response.data)
+	const data = parsed.success ? parsed.data.data : response.data?.data || []
+
+	for (const m of data as Array<z.infer<typeof DeepInfraModelSchema>>) {
+		const meta = m.metadata || {}
+		const tags = meta.tags || []
+
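+		// Capability flags come from metadata tags; context and output limits fall back
+		// to conservative defaults (8K context, ~20% of context for output) when not reported.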
+		const contextWindow = typeof meta.context_length === "number" ? meta.context_length : 8192
+		const maxTokens = typeof meta.max_tokens === "number" ? meta.max_tokens : Math.ceil(contextWindow * 0.2)
+
+		const info: ModelInfo = {
+			maxTokens,
+			contextWindow,
+			supportsImages: tags.includes("vision"),
+			supportsPromptCache: tags.includes("prompt_cache"),
+			inputPrice: meta.pricing?.input_tokens,
+			outputPrice: meta.pricing?.output_tokens,
+			cacheReadsPrice: meta.pricing?.cache_read_tokens,
+			description: meta.description,
+		}
+
+		models[m.id] = info
+	}
+
+	return models
+}
diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts
index 0005e8205f..a91cdaf994 100644
--- a/src/api/providers/fetchers/modelCache.ts
+++ b/src/api/providers/fetchers/modelCache.ts
@@ -19,6 +19,7 @@ import { GetModelsOptions } from "../../../shared/api"
 import { getOllamaModels } from "./ollama"
 import { getLMStudioModels } from "./lmstudio"
 import { getIOIntelligenceModels } from "./io-intelligence"
+import { getDeepInfraModels } from "./deepinfra"
 
 const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
 
 async function writeModels(router: RouterName, data: ModelRecord) {
@@ -79,6 +80,9 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
 			case "lmstudio":
 				models = await getLMStudioModels(options.baseUrl)
 				break
+			case "deepinfra":
+				models = await getDeepInfraModels(options.apiKey, options.baseUrl)
+				break
 			case "io-intelligence":
 				models = await getIOIntelligenceModels(options.apiKey)
 				break
diff --git a/src/api/providers/index.ts b/src/api/providers/index.ts
index c3786c5f56..85d877b6bc 100644
--- a/src/api/providers/index.ts
+++ b/src/api/providers/index.ts
@@ -33,3 +33,4 @@ export { FireworksHandler } from "./fireworks"
 export { RooHandler } from "./roo"
 export { FeatherlessHandler } from "./featherless"
 export { VercelAiGatewayHandler } from "./vercel-ai-gateway"
+export { DeepInfraHandler } from "./deepinfra"
diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts
index a495489cc1..bd842a08b1 100644
--- a/src/core/webview/webviewMessageHandler.ts
+++ b/src/core/webview/webviewMessageHandler.ts
@@ -550,6 +550,7 @@ export const webviewMessageHandler = async (
 				litellm: {},
 				ollama: {},
 				lmstudio: {},
+				deepinfra: {},
 			}
 
 			const safeGetModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
@@ -577,6 +578,14 @@ export const webviewMessageHandler = async (
 				{ key: "glama", options: { provider: "glama" } },
 				{ key: "unbound", options: { provider: "unbound", apiKey: apiConfiguration.unboundApiKey } },
 				{ key: "vercel-ai-gateway", options: { provider: "vercel-ai-gateway" } },
+				{
+					key: "deepinfra",
+					options: {
+						provider: "deepinfra",
+						apiKey: apiConfiguration.deepInfraApiKey,
+						baseUrl: apiConfiguration.deepInfraBaseUrl,
+					},
+				},
 			]
 
 			// Add IO Intelligence if API key is provided
diff --git a/src/shared/ProfileValidator.ts b/src/shared/ProfileValidator.ts
index 57c10301a2..78ff6ed9fe 100644
--- a/src/shared/ProfileValidator.ts
+++ b/src/shared/ProfileValidator.ts
@@ -90,6 +90,8 @@ export class ProfileValidator {
 				return profile.requestyModelId
 			case "io-intelligence":
 				return profile.ioIntelligenceModelId
+			case "deepinfra":
+				return profile.deepInfraModelId
 			case "human-relay":
 			case "fake-ai":
 			default:
diff --git a/src/shared/api.ts b/src/shared/api.ts
index 30dfd7393b..eb3ae124a8 100644
--- a/src/shared/api.ts
+++ b/src/shared/api.ts
@@ -27,6 +27,7 @@ const routerNames = [
 	"ollama",
 	"lmstudio",
 	"io-intelligence",
+	"deepinfra",
 	"vercel-ai-gateway",
 ] as const
 
@@ -151,5 +152,6 @@ export type GetModelsOptions =
 	| { provider: "litellm"; apiKey: string; baseUrl: string }
 	| { provider: "ollama"; baseUrl?: string }
 	| { provider: "lmstudio"; baseUrl?: string }
+	| { provider: "deepinfra"; apiKey?: string; baseUrl?: string }
 	| { provider: "io-intelligence"; apiKey: string }
 	| { provider: "vercel-ai-gateway" }
diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx
index 80ecd75ae4..32a0d452ea 100644
--- a/webview-ui/src/components/settings/ApiOptions.tsx
+++ b/webview-ui/src/components/settings/ApiOptions.tsx
@@ -36,6 +36,7 @@ import {
 	ioIntelligenceDefaultModelId,
 	rooDefaultModelId,
 	vercelAiGatewayDefaultModelId,
+	deepInfraDefaultModelId,
 } from "@roo-code/types"
 
 import { vscode } from "@src/utils/vscode"
@@ -93,6 +94,7 @@ import {
 	Fireworks,
 	Featherless,
 	VercelAiGateway,
+	DeepInfra,
 } from "./providers"
 
 import { MODELS_BY_PROVIDER, PROVIDERS } from "./constants"
@@ -226,6 +228,8 @@ const ApiOptions = ({
 				vscode.postMessage({ type: "requestVsCodeLmModels" })
 			} else if (selectedProvider === "litellm") {
 				vscode.postMessage({ type: "requestRouterModels" })
+			} else if (selectedProvider === "deepinfra") {
+				vscode.postMessage({ type: "requestRouterModels" })
 			}
 		},
 		250,
@@ -238,6 +242,8 @@ const ApiOptions = ({
 			apiConfiguration?.lmStudioBaseUrl,
 			apiConfiguration?.litellmBaseUrl,
 			apiConfiguration?.litellmApiKey,
+			apiConfiguration?.deepInfraApiKey,
+			apiConfiguration?.deepInfraBaseUrl,
 			customHeaders,
 		],
 	)
@@ -305,6 +311,7 @@ const ApiOptions = ({
 		}
 	>
 > = {
+	deepinfra: { field: "deepInfraModelId", default: deepInfraDefaultModelId },
 	openrouter: { field: "openRouterModelId", default: openRouterDefaultModelId },
 	glama: { field: "glamaModelId", default: glamaDefaultModelId },
 	unbound: { field: "unboundModelId", default: unboundDefaultModelId },
@@ -487,6 +494,17 @@ const ApiOptions = ({
 				/>
 			)}
 
+			{selectedProvider === "deepinfra" && (
+				<DeepInfra
+					apiConfiguration={apiConfiguration}
+					setApiConfigurationField={setApiConfigurationField}
+					routerModels={routerModels}
+					refetchRouterModels={refetchRouterModels}
+					organizationAllowList={organizationAllowList}
+					modelValidationError={modelValidationError}
+				/>
+			)}
+
 			{selectedProvider === "anthropic" && (
 				<Anthropic apiConfiguration={apiConfiguration} setApiConfigurationField={setApiConfigurationField} />
 			)}
diff --git a/webview-ui/src/components/settings/ModelPicker.tsx b/webview-ui/src/components/settings/ModelPicker.tsx
index 949e0a081f..74e3d31f00 100644
--- a/webview-ui/src/components/settings/ModelPicker.tsx
+++ b/webview-ui/src/components/settings/ModelPicker.tsx
@@ -34,6 +34,7 @@ type ModelIdKey = keyof Pick<
 	| "requestyModelId"
 	| "openAiModelId"
 	| "litellmModelId"
+	| "deepInfraModelId"
 	| "ioIntelligenceModelId"
 	| "vercelAiGatewayModelId"
 >
diff --git a/webview-ui/src/components/settings/constants.ts b/webview-ui/src/components/settings/constants.ts
index 9aa02bbf53..ae336730ff 100644
--- a/webview-ui/src/components/settings/constants.ts
+++ b/webview-ui/src/components/settings/constants.ts
@@ -48,6 +48,7 @@ export const MODELS_BY_PROVIDER: Partial<
+	deepinfra: { id: "deepinfra", label: "DeepInfra", models: [] },
diff --git a/webview-ui/src/components/settings/providers/DeepInfra.tsx b/webview-ui/src/components/settings/providers/DeepInfra.tsx
new file mode 100644
--- /dev/null
+++ b/webview-ui/src/components/settings/providers/DeepInfra.tsx
@@ -0,0 +1,94 @@
+import { useCallback, useState, useEffect } from "react"
+import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
+
+import { type ProviderSettings, type OrganizationAllowList, deepInfraDefaultModelId } from "@roo-code/types"
+import type { RouterModels } from "@roo/api"
+
+import { useAppTranslation } from "@src/i18n/TranslationContext"
+import { Button } from "@src/components/ui"
+
+import { inputEventTransform } from "../transforms"
+import { ModelPicker } from "../ModelPicker"
+
+type DeepInfraProps = {
+	apiConfiguration: ProviderSettings
+	setApiConfigurationField: (field: keyof ProviderSettings, value: ProviderSettings[keyof ProviderSettings]) => void
+	routerModels?: RouterModels
+	refetchRouterModels: () => void
+	organizationAllowList: OrganizationAllowList
+	modelValidationError?: string
+}
+
+export const DeepInfra = ({
+	apiConfiguration,
+	setApiConfigurationField,
+	routerModels,
+	refetchRouterModels,
+	organizationAllowList,
+	modelValidationError,
+}: DeepInfraProps) => {
+	const { t } = useAppTranslation()
+
+	const [didRefetch, setDidRefetch] = useState<boolean>()
+
+	const handleInputChange = useCallback(
+		<K extends keyof ProviderSettings, E>(
+			field: K,
+			transform: (event: E) => ProviderSettings[K] = inputEventTransform,
+		) =>
+			(event: E | Event) => {
+				setApiConfigurationField(field, transform(event as E))
+			},
+		[setApiConfigurationField],
+	)
+
+	useEffect(() => {
+		// When the base URL or API key changes, trigger a silent refresh of the models.
+		// The outer ApiOptions debounces and sends requestRouterModels; this keeps the UI responsive.
+	}, [apiConfiguration.deepInfraBaseUrl, apiConfiguration.deepInfraApiKey])
+
+	return (
+		<>
+			<VSCodeTextField
+				value={apiConfiguration?.deepInfraBaseUrl || ""}
+				type="url"
+				onInput={handleInputChange("deepInfraBaseUrl")}
+				className="w-full">
+				<label className="block font-medium mb-1">{t("settings:providers.deepInfraBaseUrl")}</label>
+			</VSCodeTextField>
+			<VSCodeTextField
+				value={apiConfiguration?.deepInfraApiKey || ""}
+				type="password"
+				onInput={handleInputChange("deepInfraApiKey")}
+				placeholder={t("settings:placeholders.apiKey")}
+				className="w-full">
+				<label className="block font-medium mb-1">{t("settings:providers.deepInfraApiKey")}</label>
+			</VSCodeTextField>
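+			{/* Manual refresh re-requests the DeepInfra model list; didRefetch drives the hint below */}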
+ {t("settings:providers.refreshModels.hint")} +
+			)}
+			<ModelPicker
+				apiConfiguration={apiConfiguration}
+				setApiConfigurationField={setApiConfigurationField}
+				defaultModelId={deepInfraDefaultModelId}
+				models={routerModels?.deepinfra ?? {}}
+				modelIdKey="deepInfraModelId"
+				serviceName="DeepInfra"
+				serviceUrl="https://deepinfra.com"
+				organizationAllowList={organizationAllowList}
+				errorMessage={modelValidationError}
+			/>
+		</>
+	)
+}
diff --git a/webview-ui/src/components/settings/providers/index.ts b/webview-ui/src/components/settings/providers/index.ts
index eedbba0c29..fe0e6cecf9 100644
--- a/webview-ui/src/components/settings/providers/index.ts
+++ b/webview-ui/src/components/settings/providers/index.ts
@@ -29,3 +29,4 @@ export { LiteLLM } from "./LiteLLM"
 export { Fireworks } from "./Fireworks"
 export { Featherless } from "./Featherless"
 export { VercelAiGateway } from "./VercelAiGateway"
+export { DeepInfra } from "./DeepInfra"
diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts
index e9470e0902..b7fe4ff03d 100644
--- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts
+++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts
@@ -56,6 +56,7 @@ import {
 	qwenCodeModels,
 	vercelAiGatewayDefaultModelId,
 	BEDROCK_CLAUDE_SONNET_4_MODEL_ID,
+	deepInfraDefaultModelId,
 } from "@roo-code/types"
 
 import type { ModelRecord, RouterModels } from "@roo/api"
@@ -268,6 +269,11 @@ function getSelectedModel({
 				info: info || undefined,
 			}
 		}
+		case "deepinfra": {
+			const id = apiConfiguration.deepInfraModelId ?? deepInfraDefaultModelId
+			const info = routerModels.deepinfra?.[id]
+			return { id, info }
+		}
 		case "vscode-lm": {
 			const id = apiConfiguration?.vsCodeLmModelSelector
 				? `${apiConfiguration.vsCodeLmModelSelector.vendor}/${apiConfiguration.vsCodeLmModelSelector.family}`
diff --git a/webview-ui/src/utils/__tests__/validate.test.ts b/webview-ui/src/utils/__tests__/validate.test.ts
index 2f62dd181d..c9fb7bfd42 100644
--- a/webview-ui/src/utils/__tests__/validate.test.ts
+++ b/webview-ui/src/utils/__tests__/validate.test.ts
@@ -39,6 +39,7 @@ describe("Model Validation Functions", () => {
 		litellm: {},
 		ollama: {},
 		lmstudio: {},
+		deepinfra: {},
 		"io-intelligence": {},
 		"vercel-ai-gateway": {},
 	}
diff --git a/webview-ui/src/utils/validate.ts b/webview-ui/src/utils/validate.ts
index 1cbeba76d0..58cc8d38e8 100644
--- a/webview-ui/src/utils/validate.ts
+++ b/webview-ui/src/utils/validate.ts
@@ -47,6 +47,11 @@ function validateModelsAndKeysProvided(apiConfiguration: ProviderSettings): stri
 				return i18next.t("settings:validation.apiKey")
 			}
 			break
+		case "deepinfra":
+			if (!apiConfiguration.deepInfraApiKey) {
+				return i18next.t("settings:validation.apiKey")
+			}
+			break
 		case "litellm":
 			if (!apiConfiguration.litellmApiKey) {
 				return i18next.t("settings:validation.apiKey")
@@ -193,6 +198,8 @@ function getModelIdForProvider(apiConfiguration: ProviderSettings, provider: str
 			return apiConfiguration.unboundModelId
 		case "requesty":
 			return apiConfiguration.requestyModelId
+		case "deepinfra":
+			return apiConfiguration.deepInfraModelId
 		case "litellm":
 			return apiConfiguration.litellmModelId
 		case "openai":
@@ -271,6 +278,9 @@ export function validateModelId(apiConfiguration: ProviderSettings, routerModels
 		case "requesty":
 			modelId = apiConfiguration.requestyModelId
 			break
+		case "deepinfra":
+			modelId = apiConfiguration.deepInfraModelId
+			break
 		case "ollama":
 			modelId = apiConfiguration.ollamaModelId
 			break

From 900d8b8733725f3ca7999e5ce28298095113be64 Mon Sep 17 00:00:00 2001
From: Thach Nguyen
Date: Wed, 3 Sep 2025 23:37:21 -0700
Subject: [PATCH 2/3] fix tests

---
 src/core/webview/__tests__/ClineProvider.spec.ts      |  4 ++++
 .../webview/__tests__/webviewMessageHandler.spec.ts   | 13 +++++++++++++
 2 files changed, 17 insertions(+)

diff --git a/src/core/webview/__tests__/ClineProvider.spec.ts b/src/core/webview/__tests__/ClineProvider.spec.ts
index 400ce50468..8b04721638 100644
--- a/src/core/webview/__tests__/ClineProvider.spec.ts
+++ b/src/core/webview/__tests__/ClineProvider.spec.ts
@@ -2680,6 +2680,7 @@ describe("ClineProvider - Router Models", () => {
 			expect(mockPostMessage).toHaveBeenCalledWith({
 				type: "routerModels",
 				routerModels: {
+					deepinfra: mockModels,
 					openrouter: mockModels,
 					requesty: mockModels,
 					glama: mockModels,
@@ -2719,6 +2720,7 @@ describe("ClineProvider - Router Models", () => {
 				.mockResolvedValueOnce(mockModels) // glama success
 				.mockRejectedValueOnce(new Error("Unbound API error")) // unbound fail
 				.mockResolvedValueOnce(mockModels) // vercel-ai-gateway success
+				.mockResolvedValueOnce(mockModels) // deepinfra success
 				.mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm fail
 
 			await messageHandler({ type: "requestRouterModels" })
@@ -2727,6 +2729,7 @@ describe("ClineProvider - Router Models", () => {
 			expect(mockPostMessage).toHaveBeenCalledWith({
 				type: "routerModels",
 				routerModels: {
+					deepinfra: mockModels,
 					openrouter: mockModels,
 					requesty: {},
 					glama: mockModels,
@@ -2838,6 +2841,7 @@ describe("ClineProvider - Router Models", () => {
 			expect(mockPostMessage).toHaveBeenCalledWith({
 				type: "routerModels",
 				routerModels: {
+					deepinfra: mockModels,
 					openrouter: mockModels,
 					requesty: mockModels,
 					glama: mockModels,
diff --git a/src/core/webview/__tests__/webviewMessageHandler.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.spec.ts
index 06dbc03502..9241d93c68 100644
--- a/src/core/webview/__tests__/webviewMessageHandler.spec.ts
+++ b/src/core/webview/__tests__/webviewMessageHandler.spec.ts
@@ -174,6 +174,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 		})
 
 		// Verify getModels was called for each provider
+		expect(mockGetModels).toHaveBeenCalledWith({ provider: "deepinfra" })
 		expect(mockGetModels).toHaveBeenCalledWith({ provider: "openrouter" })
 		expect(mockGetModels).toHaveBeenCalledWith({ provider: "requesty", apiKey: "requesty-key" })
 		expect(mockGetModels).toHaveBeenCalledWith({ provider: "glama" })
@@ -189,6 +190,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 		expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
 			type: "routerModels",
 			routerModels: {
+				deepinfra: mockModels,
 				openrouter: mockModels,
 				requesty: mockModels,
 				glama: mockModels,
@@ -277,6 +279,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 		expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
 			type: "routerModels",
 			routerModels: {
+				deepinfra: mockModels,
 				openrouter: mockModels,
 				requesty: mockModels,
 				glama: mockModels,
@@ -306,6 +309,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 			.mockResolvedValueOnce(mockModels) // glama
 			.mockRejectedValueOnce(new Error("Unbound API error")) // unbound
 			.mockResolvedValueOnce(mockModels) // vercel-ai-gateway
+			.mockResolvedValueOnce(mockModels) // deepinfra
 			.mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm
 
 		await webviewMessageHandler(mockClineProvider, {
@@ -316,6 +320,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 		expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
 			type: "routerModels",
 			routerModels: {
+				deepinfra: mockModels,
 				openrouter: mockModels,
 				requesty: {},
 				glama: mockModels,
@@ -358,6 +363,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
 			.mockRejectedValueOnce(new Error("Glama API error")) // glama
 			.mockRejectedValueOnce(new Error("Unbound API error")) // unbound
error")) // vercel-ai-gateway + .mockRejectedValueOnce(new Error("DeepInfra API error")) // deepinfra .mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm await webviewMessageHandler(mockClineProvider, { @@ -393,6 +399,13 @@ describe("webviewMessageHandler - requestRouterModels", () => { values: { provider: "unbound" }, }) + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "singleRouterModelFetchResponse", + success: false, + error: "DeepInfra API error", + values: { provider: "deepinfra" }, + }) + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ type: "singleRouterModelFetchResponse", success: false, From 0724c8242cf8cbc70b1c69915ea15dc6bb1ec2d5 Mon Sep 17 00:00:00 2001 From: Thach Nguyen Date: Wed, 3 Sep 2025 23:52:39 -0700 Subject: [PATCH 3/3] Add changeset --- .changeset/petite-rats-admire.md | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changeset/petite-rats-admire.md diff --git a/.changeset/petite-rats-admire.md b/.changeset/petite-rats-admire.md new file mode 100644 index 0000000000..84568ed3cc --- /dev/null +++ b/.changeset/petite-rats-admire.md @@ -0,0 +1,6 @@ +--- +"roo-cline": minor +"@roo-code/types": patch +--- + +Added DeepInfra provider with dynamic model fetching and prompt caching