From 3c24584a99a8cf26b1cce6665a2710fa10f87879 Mon Sep 17 00:00:00 2001
From: Roo Code
Date: Thu, 4 Sep 2025 07:30:24 +0000
Subject: [PATCH] feat: add DeepInfra as a model provider

- Add DeepInfra types and model definitions
- Implement DeepInfra provider handler with OpenAI-compatible API
- Add dynamic model fetching from DeepInfra API
- Support prompt caching for reduced costs
- Update UI components to support DeepInfra selection
- Add DeepInfra to router configuration

Implements #7661
---
 packages/types/src/provider-settings.ts      |  11 ++
 packages/types/src/providers/deepinfra.ts    |  12 ++
 packages/types/src/providers/index.ts        |   1 +
 src/api/index.ts                             |   3 +
 src/api/providers/deepinfra.ts               | 102 ++++++++++++
 src/api/providers/fetchers/deepinfra.ts      | 150 ++++++++++++++++++
 src/api/providers/fetchers/modelCache.ts     |   5 +
 src/api/providers/index.ts                   |   1 +
 src/shared/api.ts                            |   2 +
 .../components/ui/hooks/useSelectedModel.ts  |   8 +-
 .../src/utils/__tests__/validate.test.ts     |   1 +
 11 files changed, 295 insertions(+), 1 deletion(-)
 create mode 100644 packages/types/src/providers/deepinfra.ts
 create mode 100644 src/api/providers/deepinfra.ts
 create mode 100644 src/api/providers/fetchers/deepinfra.ts

diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts
index 090dfe6693..c1b2c83168 100644
--- a/packages/types/src/provider-settings.ts
+++ b/packages/types/src/provider-settings.ts
@@ -34,6 +34,7 @@ import {
 export const providerNames = [
     "anthropic",
     "claude-code",
+    "deepinfra",
     "glama",
     "openrouter",
     "bedrock",
@@ -294,6 +295,11 @@ const cerebrasSchema = apiModelIdProviderModelSchema.extend({
     cerebrasApiKey: z.string().optional(),
 })
 
+const deepInfraSchema = baseProviderSettingsSchema.extend({
+    deepInfraApiKey: z.string().optional(),
+    deepInfraModelId: z.string().optional(),
+})
+
 const sambaNovaSchema = apiModelIdProviderModelSchema.extend({
     sambaNovaApiKey: z.string().optional(),
 })
@@ -336,6 +342,7 @@ const defaultSchema = z.object({
 export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProvider", [
     anthropicSchema.merge(z.object({ apiProvider: z.literal("anthropic") })),
     claudeCodeSchema.merge(z.object({ apiProvider: z.literal("claude-code") })),
+    deepInfraSchema.merge(z.object({ apiProvider: z.literal("deepinfra") })),
     glamaSchema.merge(z.object({ apiProvider: z.literal("glama") })),
     openRouterSchema.merge(z.object({ apiProvider: z.literal("openrouter") })),
     bedrockSchema.merge(z.object({ apiProvider: z.literal("bedrock") })),
@@ -376,6 +383,7 @@ export const providerSettingsSchema = z.object({
     apiProvider: providerNamesSchema.optional(),
     ...anthropicSchema.shape,
     ...claudeCodeSchema.shape,
+    ...deepInfraSchema.shape,
     ...glamaSchema.shape,
     ...openRouterSchema.shape,
     ...bedrockSchema.shape,
@@ -426,6 +434,7 @@ export const PROVIDER_SETTINGS_KEYS = providerSettingsSchema.keyof().options
 
 export const MODEL_ID_KEYS: Partial<keyof ProviderSettings>[] = [
     "apiModelId",
+    "deepInfraModelId",
     "glamaModelId",
     "openRouterModelId",
     "openAiModelId",
@@ -489,6 +498,7 @@ export const MODELS_BY_PROVIDER: Record<
         label: "Chutes AI",
         models: Object.keys(chutesModels),
     },
+    deepinfra: { id: "deepinfra", label: "DeepInfra", models: [] },
     "claude-code": { id: "claude-code", label: "Claude Code", models: Object.keys(claudeCodeModels) },
     deepseek: {
         id: "deepseek",
@@ -563,6 +573,7 @@ export const MODELS_BY_PROVIDER: Record<
 }
 
 export const dynamicProviders = [
+    "deepinfra",
     "glama",
     "huggingface",
     "litellm",
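
A quick sanity check of the settings plumbing above, as a reviewer note rather than part of the patch: a minimal sketch, assuming `providerSettingsSchemaDiscriminated` is exported from `@roo-code/types` as shown, with placeholder values throughout.

```ts
import { providerSettingsSchemaDiscriminated } from "@roo-code/types"

// Hypothetical saved configuration; the key names follow deepInfraSchema above.
const candidate = {
	apiProvider: "deepinfra",
	deepInfraApiKey: "dk-example", // placeholder, not a real key
	deepInfraModelId: "meta-llama/Llama-3.3-70B-Instruct",
}

// The union discriminates on apiProvider, so an unknown provider string
// fails validation here; unrecognized keys are stripped by zod's defaults.
const parsed = providerSettingsSchemaDiscriminated.safeParse(candidate)
if (parsed.success && parsed.data.apiProvider === "deepinfra") {
	console.log(parsed.data.deepInfraModelId)
}
```
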
diff --git a/packages/types/src/providers/deepinfra.ts b/packages/types/src/providers/deepinfra.ts
new file mode 100644
index 0000000000..2f952955b8
--- /dev/null
+++ b/packages/types/src/providers/deepinfra.ts
@@ -0,0 +1,12 @@
+import type { ModelInfo } from "../model.js"
+
+// DeepInfra models are fetched dynamically from their API
+// This type represents the model IDs that will be available
+export type DeepInfraModelId = string
+
+// Default model to use when none is specified
+export const deepInfraDefaultModelId: DeepInfraModelId = "meta-llama/Llama-3.3-70B-Instruct"
+
+// DeepInfra models will be fetched dynamically, so we provide an empty object
+// The actual models will be populated at runtime via the API
+export const deepInfraModels = {} as const satisfies Record<string, ModelInfo>
diff --git a/packages/types/src/providers/index.ts b/packages/types/src/providers/index.ts
index 97fa10ca82..085b461f01 100644
--- a/packages/types/src/providers/index.ts
+++ b/packages/types/src/providers/index.ts
@@ -3,6 +3,7 @@ export * from "./bedrock.js"
 export * from "./cerebras.js"
 export * from "./chutes.js"
 export * from "./claude-code.js"
+export * from "./deepinfra.js"
 export * from "./deepseek.js"
 export * from "./doubao.js"
 export * from "./featherless.js"
diff --git a/src/api/index.ts b/src/api/index.ts
index b50afbb023..b639246991 100644
--- a/src/api/index.ts
+++ b/src/api/index.ts
@@ -9,6 +9,7 @@ import {
     AnthropicHandler,
     AwsBedrockHandler,
     CerebrasHandler,
+    DeepInfraHandler,
     OpenRouterHandler,
     VertexHandler,
     AnthropicVertexHandler,
@@ -114,6 +115,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
             return new GeminiHandler(options)
         case "openai-native":
             return new OpenAiNativeHandler(options)
+        case "deepinfra":
+            return new DeepInfraHandler(options)
         case "deepseek":
             return new DeepSeekHandler(options)
         case "doubao":
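
With the dispatch case in place, the handler can be exercised end to end. A usage sketch under the patch's own types; the import path and configuration values are illustrative, not prescribed by the patch.

```ts
import { buildApiHandler } from "./src/api" // path relative to the repo root; adjust as needed

// Placeholder settings; the key names come from deepInfraSchema.
const handler = buildApiHandler({
	apiProvider: "deepinfra",
	deepInfraApiKey: process.env.DEEPINFRA_API_KEY,
	deepInfraModelId: "meta-llama/Llama-3.3-70B-Instruct",
})

// Until the dynamic model list is fetched, getModel() reports the static
// defaults hard-coded in DeepInfraHandler.getModel() below.
console.log(handler.getModel().id) // "meta-llama/Llama-3.3-70B-Instruct"
```
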
diff --git a/src/api/providers/deepinfra.ts b/src/api/providers/deepinfra.ts
new file mode 100644
index 0000000000..e83a04f377
--- /dev/null
+++ b/src/api/providers/deepinfra.ts
@@ -0,0 +1,102 @@
+import { type DeepInfraModelId, deepInfraDefaultModelId } from "@roo-code/types"
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
+
+import type { ApiHandlerOptions } from "../../shared/api"
+import type { ApiHandlerCreateMessageMetadata } from "../index"
+import type { ModelInfo } from "@roo-code/types"
+import { ApiStream } from "../transform/stream"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { calculateApiCostOpenAI } from "../../shared/cost"
+
+import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
+
+// Enhanced usage interface to support DeepInfra's cached token fields
+interface DeepInfraUsage extends OpenAI.CompletionUsage {
+    prompt_tokens_details?: {
+        cached_tokens?: number
+    }
+}
+
+export class DeepInfraHandler extends BaseOpenAiCompatibleProvider<DeepInfraModelId> {
+    constructor(options: ApiHandlerOptions) {
+        // Initialize with empty models, will be populated dynamically
+        super({
+            ...options,
+            providerName: "DeepInfra",
+            baseURL: "https://api.deepinfra.com/v1/openai",
+            apiKey: options.deepInfraApiKey,
+            defaultProviderModelId: deepInfraDefaultModelId,
+            providerModels: {},
+            defaultTemperature: 0.7,
+        })
+    }
+
+    override getModel() {
+        const modelId = this.options.deepInfraModelId || deepInfraDefaultModelId
+
+        // For DeepInfra, we use a default model configuration
+        // The actual model info will be fetched dynamically via the fetcher
+        const defaultModelInfo: ModelInfo = {
+            maxTokens: 4096,
+            contextWindow: 32768,
+            supportsImages: false,
+            supportsPromptCache: true,
+            inputPrice: 0.15,
+            outputPrice: 0.6,
+            cacheReadsPrice: 0.075, // 50% discount for cached tokens
+            description: "DeepInfra model",
+        }
+
+        return { id: modelId, info: defaultModelInfo }
+    }
+
+    override async *createMessage(
+        systemPrompt: string,
+        messages: Anthropic.Messages.MessageParam[],
+        metadata?: ApiHandlerCreateMessageMetadata,
+    ): ApiStream {
+        const stream = await this.createStream(systemPrompt, messages, metadata)
+
+        for await (const chunk of stream) {
+            const delta = chunk.choices[0]?.delta
+
+            if (delta?.content) {
+                yield {
+                    type: "text",
+                    text: delta.content,
+                }
+            }
+
+            if (chunk.usage) {
+                yield* this.yieldUsage(chunk.usage as DeepInfraUsage)
+            }
+        }
+    }
+
+    private async *yieldUsage(usage: DeepInfraUsage | undefined): ApiStream {
+        const { info } = this.getModel()
+        const inputTokens = usage?.prompt_tokens || 0
+        const outputTokens = usage?.completion_tokens || 0
+
+        const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
+
+        // DeepInfra does not track cache writes separately
+        const cacheWriteTokens = 0
+
+        // Calculate cost using OpenAI-compatible cost calculation
+        const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
+
+        // Calculate non-cached input tokens for proper reporting
+        const nonCachedInputTokens = Math.max(0, inputTokens - cacheReadTokens - cacheWriteTokens)
+
+        yield {
+            type: "usage",
+            inputTokens: nonCachedInputTokens,
+            outputTokens,
+            cacheWriteTokens,
+            cacheReadTokens,
+            totalCost,
+        }
+    }
+}
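
The cache accounting in `yieldUsage` is easiest to sanity-check with concrete numbers. A sketch of the arithmetic, assuming prices are USD per million tokens and that `prompt_tokens` includes the cached tokens (which is why the handler reports `nonCachedInputTokens` separately):

```ts
// 10,000 prompt tokens, 4,000 of them served from cache, 500 output tokens.
const info = { inputPrice: 0.15, outputPrice: 0.6, cacheReadsPrice: 0.075 }
const promptTokens = 10_000
const cachedTokens = 4_000
const outputTokens = 500

// Cached tokens bill at cacheReadsPrice instead of the full input price.
const nonCached = promptTokens - cachedTokens // 6,000, reported as inputTokens
const totalCost =
	(nonCached * info.inputPrice) / 1_000_000 +
	(cachedTokens * info.cacheReadsPrice) / 1_000_000 +
	(outputTokens * info.outputPrice) / 1_000_000
console.log(totalCost.toFixed(6)) // 0.001500: caching halves the cost of the 4,000 cached tokens
```
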
diff --git a/src/api/providers/fetchers/deepinfra.ts b/src/api/providers/fetchers/deepinfra.ts
new file mode 100644
index 0000000000..4f2fa3581a
--- /dev/null
+++ b/src/api/providers/fetchers/deepinfra.ts
@@ -0,0 +1,150 @@
+import axios from "axios"
+import { z } from "zod"
+
+import type { ModelInfo } from "@roo-code/types"
+
+import { parseApiPrice } from "../../../shared/cost"
+
+/**
+ * DeepInfra Model Schema
+ */
+const deepInfraModelSchema = z.object({
+    model_name: z.string(),
+    type: z.string().optional(),
+    max_tokens: z.number().optional(),
+    context_length: z.number().optional(),
+    pricing: z
+        .object({
+            input: z.number().optional(),
+            output: z.number().optional(),
+            cached_input: z.number().optional(),
+        })
+        .optional(),
+    description: z.string().optional(),
+    capabilities: z.array(z.string()).optional(),
+})
+
+type DeepInfraModel = z.infer<typeof deepInfraModelSchema>
+
+/**
+ * DeepInfra Models Response Schema
+ */
+const deepInfraModelsResponseSchema = z.array(deepInfraModelSchema)
+
+type DeepInfraModelsResponse = z.infer<typeof deepInfraModelsResponseSchema>
+
+/**
+ * Fetch models from DeepInfra API
+ */
+export async function getDeepInfraModels(apiKey?: string): Promise<Record<string, ModelInfo>> {
+    const models: Record<string, ModelInfo> = {}
+    const baseURL = "https://api.deepinfra.com/v1/openai"
+
+    try {
+        // DeepInfra requires authentication to fetch models
+        if (!apiKey) {
+            console.log("DeepInfra API key not provided, returning empty models")
+            return models
+        }
+
+        const response = await axios.get(`${baseURL}/models`, {
+            headers: {
+                Authorization: `Bearer ${apiKey}`,
+            },
+        })
+
+        const result = deepInfraModelsResponseSchema.safeParse(response.data)
+        const data = result.success ? result.data : response.data
+
+        if (!result.success) {
+            console.error("DeepInfra models response is invalid", result.error.format())
+        }
+
+        // Process each model from the response
+        for (const model of data) {
+            // Skip non-text models
+            if (model.type && !["text", "chat", "instruct"].includes(model.type)) {
+                continue
+            }
+
+            const modelInfo: ModelInfo = {
+                maxTokens: model.max_tokens || 4096,
+                contextWindow: model.context_length || 32768,
+                supportsImages: model.capabilities?.includes("vision") || false,
+                supportsPromptCache: true, // DeepInfra supports prompt caching
+                // ModelInfo prices are USD per million tokens, the same unit the
+                // API is assumed to report, so values pass through unconverted
+                inputPrice: model.pricing?.input ?? 0.15,
+                outputPrice: model.pricing?.output ?? 0.6,
+                cacheReadsPrice: model.pricing?.cached_input,
+                description: model.description,
+            }
+
+            models[model.model_name] = modelInfo
+        }
+
+        // If the API doesn't return models, provide some default popular models
+        if (Object.keys(models).length === 0) {
+            console.log("No models returned from DeepInfra API, using default models")
+
+            // Default popular models on DeepInfra (prices in USD per million tokens)
+            models["meta-llama/Llama-3.3-70B-Instruct"] = {
+                maxTokens: 8192,
+                contextWindow: 131072,
+                supportsImages: false,
+                supportsPromptCache: true,
+                inputPrice: 0.35,
+                outputPrice: 0.4,
+                cacheReadsPrice: 0.175,
+                description: "Meta Llama 3.3 70B Instruct model",
+            }
+
+            models["meta-llama/Llama-3.1-8B-Instruct"] = {
+                maxTokens: 4096,
+                contextWindow: 131072,
+                supportsImages: false,
+                supportsPromptCache: true,
+                inputPrice: 0.06,
+                outputPrice: 0.06,
+                cacheReadsPrice: 0.03,
+                description: "Meta Llama 3.1 8B Instruct model",
+            }
+
+            models["Qwen/Qwen2.5-72B-Instruct"] = {
+                maxTokens: 8192,
+                contextWindow: 131072,
+                supportsImages: false,
+                supportsPromptCache: true,
+                inputPrice: 0.35,
+                outputPrice: 0.4,
+                cacheReadsPrice: 0.175,
+                description: "Qwen 2.5 72B Instruct model",
+            }
+
+            models["mistralai/Mixtral-8x7B-Instruct-v0.1"] = {
+                maxTokens: 4096,
+                contextWindow: 32768,
+                supportsImages: false,
+                supportsPromptCache: true,
+                inputPrice: 0.24,
+                outputPrice: 0.24,
+                cacheReadsPrice: 0.12,
+                description: "Mistral Mixtral 8x7B Instruct model",
+            }
+        }
+    } catch (error) {
+        console.error(`Error fetching DeepInfra models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`)
+
+        // Return default models on error (prices in USD per million tokens)
+        models["meta-llama/Llama-3.3-70B-Instruct"] = {
+            maxTokens: 8192,
+            contextWindow: 131072,
+            supportsImages: false,
+            supportsPromptCache: true,
+            inputPrice: 0.35,
+            outputPrice: 0.4,
+            cacheReadsPrice: 0.175,
+            description: "Meta Llama 3.3 70B Instruct model",
+        }
+    }
+
+    return models
+}
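
A quick way to exercise the fetcher is a standalone script. A sketch, assuming a real key in `DEEPINFRA_API_KEY`; without one the function deliberately returns an empty record.

```ts
import { getDeepInfraModels } from "./src/api/providers/fetchers/deepinfra" // illustrative path

const models = await getDeepInfraModels(process.env.DEEPINFRA_API_KEY)

// Print a few models with the fields the UI cares about.
for (const [id, info] of Object.entries(models).slice(0, 5)) {
	console.log(id, `ctx=${info.contextWindow}`, `$${info.inputPrice}/M input`)
}
```
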
diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts
index 0005e8205f..b84a9f3793 100644
--- a/src/api/providers/fetchers/modelCache.ts
+++ b/src/api/providers/fetchers/modelCache.ts
@@ -19,6 +19,7 @@ import { GetModelsOptions } from "../../../shared/api"
 import { getOllamaModels } from "./ollama"
 import { getLMStudioModels } from "./lmstudio"
 import { getIOIntelligenceModels } from "./io-intelligence"
+import { getDeepInfraModels } from "./deepinfra"
 
 const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
 
 async function writeModels(router: RouterName, data: ModelRecord) {
@@ -55,6 +56,10 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
     try {
         switch (provider) {
+            case "deepinfra":
+                // DeepInfra models endpoint requires an API key
+                models = await getDeepInfraModels(options.apiKey)
+                break
             case "openrouter":
                 models = await getOpenRouterModels()
                 break
diff --git a/src/api/providers/index.ts b/src/api/providers/index.ts
index c3786c5f56..834115b166 100644
--- a/src/api/providers/index.ts
+++ b/src/api/providers/index.ts
@@ -4,6 +4,7 @@ export { AwsBedrockHandler } from "./bedrock"
 export { CerebrasHandler } from "./cerebras"
 export { ChutesHandler } from "./chutes"
 export { ClaudeCodeHandler } from "./claude-code"
+export { DeepInfraHandler } from "./deepinfra"
 export { DeepSeekHandler } from "./deepseek"
 export { DoubaoHandler } from "./doubao"
 export { MoonshotHandler } from "./moonshot"
diff --git a/src/shared/api.ts b/src/shared/api.ts
index 30dfd7393b..29321e9755 100644
--- a/src/shared/api.ts
+++ b/src/shared/api.ts
@@ -19,6 +19,7 @@ export type ApiHandlerOptions = Omit<ProviderSettings, "apiProvider"> & {
 
 // RouterName
 
 const routerNames = [
+    "deepinfra",
     "openrouter",
     "requesty",
     "glama",
@@ -144,6 +145,7 @@ export const getModelMaxOutputTokens = ({
 
 // GetModelsOptions
 
 export type GetModelsOptions =
+    | { provider: "deepinfra"; apiKey?: string }
     | { provider: "openrouter" }
     | { provider: "glama" }
     | { provider: "requesty"; apiKey?: string; baseUrl?: string }
diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts
index e9470e0902..ba1e7f9f7d 100644
--- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts
+++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts
@@ -8,6 +8,7 @@ import {
     bedrockModels,
     cerebrasDefaultModelId,
     cerebrasModels,
+    deepInfraDefaultModelId,
     deepSeekDefaultModelId,
     deepSeekModels,
     moonshotDefaultModelId,
@@ -119,6 +120,11 @@ function getSelectedModel({
     // users from seeing the default model if their selection is invalid
     // this gives a better UX than showing the default model
     switch (provider) {
+        case "deepinfra": {
+            const id = apiConfiguration.deepInfraModelId ?? deepInfraDefaultModelId
+            const info = routerModels.deepinfra?.[id]
+            return { id, info }
+        }
         case "openrouter": {
             const id = apiConfiguration.openRouterModelId ?? openRouterDefaultModelId
             let info = routerModels.openrouter[id]
@@ -339,7 +345,7 @@ function getSelectedModel({
         // case "human-relay":
         // case "fake-ai":
         default: {
-            provider satisfies "anthropic" | "gemini-cli" | "qwen-code" | "human-relay" | "fake-ai"
+            provider satisfies "anthropic" | "gemini-cli" | "human-relay" | "fake-ai"
             const id = apiConfiguration.apiModelId ?? anthropicDefaultModelId
             const baseInfo = anthropicModels[id as keyof typeof anthropicModels]
diff --git a/webview-ui/src/utils/__tests__/validate.test.ts b/webview-ui/src/utils/__tests__/validate.test.ts
index 2f62dd181d..aa5a7b8387 100644
--- a/webview-ui/src/utils/__tests__/validate.test.ts
+++ b/webview-ui/src/utils/__tests__/validate.test.ts
@@ -6,6 +6,7 @@ import { getModelValidationError, validateApiConfigurationExcludingModelErrors }
 
 describe("Model Validation Functions", () => {
     const mockRouterModels: RouterModels = {
+        deepinfra: {},
         openrouter: {
             "valid-model": {
                 maxTokens: 8192,
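
Taken together, the router wiring means DeepInfra models flow through the shared memory cache like any other dynamic provider. A sketch of that behavior, again with a placeholder key:

```ts
import { getModels } from "./src/api/providers/fetchers/modelCache" // illustrative path

const options = { provider: "deepinfra" as const, apiKey: process.env.DEEPINFRA_API_KEY }

// The first call hits the DeepInfra API; within the five-minute NodeCache TTL
// (stdTTL: 5 * 60 in modelCache.ts), repeat calls are served from memory.
const first = await getModels(options)
const second = await getModels(options) // cache hit
console.log(Object.keys(first).length, Object.keys(second).length)
```
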