Merged
6 changes: 6 additions & 0 deletions .changeset/petite-rats-admire.md
@@ -0,0 +1,6 @@
---
"roo-cline": minor
"@roo-code/types": patch
---

Added DeepInfra provider with dynamic model fetching and prompt caching
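
For context, a minimal sketch of how the new provider could be configured once this lands. Field names come from the provider-settings schema added in this PR; the values, and the assumption that ProviderSettings is exported from @roo-code/types, are illustrative.

import type { ProviderSettings } from "@roo-code/types"

// Hypothetical configuration; deepInfraBaseUrl is optional and the handler
// falls back to https://api.deepinfra.com/v1/openai when it is omitted.
const deepInfraSettings: ProviderSettings = {
	apiProvider: "deepinfra",
	deepInfraApiKey: "YOUR_DEEPINFRA_API_KEY",
	deepInfraModelId: "Qwen/Qwen3-Coder-480B-A35B-Instruct-Turbo",
}
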
1 change: 1 addition & 0 deletions packages/types/src/global-settings.ts
@@ -192,6 +192,7 @@ export const SECRET_STATE_KEYS = [
"groqApiKey",
"chutesApiKey",
"litellmApiKey",
"deepInfraApiKey",
"codeIndexOpenAiKey",
"codeIndexQdrantApiKey",
"codebaseIndexOpenAiCompatibleApiKey",
12 changes: 12 additions & 0 deletions packages/types/src/provider-settings.ts
@@ -48,6 +48,7 @@ export const providerNames = [
"mistral",
"moonshot",
"deepseek",
"deepinfra",
"doubao",
"qwen-code",
"unbound",
@@ -236,6 +237,12 @@ const deepSeekSchema = apiModelIdProviderModelSchema.extend({
deepSeekApiKey: z.string().optional(),
})

const deepInfraSchema = apiModelIdProviderModelSchema.extend({
deepInfraBaseUrl: z.string().optional(),
deepInfraApiKey: z.string().optional(),
deepInfraModelId: z.string().optional(),
})

const doubaoSchema = apiModelIdProviderModelSchema.extend({
doubaoBaseUrl: z.string().optional(),
doubaoApiKey: z.string().optional(),
@@ -349,6 +356,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
openAiNativeSchema.merge(z.object({ apiProvider: z.literal("openai-native") })),
mistralSchema.merge(z.object({ apiProvider: z.literal("mistral") })),
deepSeekSchema.merge(z.object({ apiProvider: z.literal("deepseek") })),
deepInfraSchema.merge(z.object({ apiProvider: z.literal("deepinfra") })),
doubaoSchema.merge(z.object({ apiProvider: z.literal("doubao") })),
moonshotSchema.merge(z.object({ apiProvider: z.literal("moonshot") })),
unboundSchema.merge(z.object({ apiProvider: z.literal("unbound") })),
@@ -389,6 +397,7 @@ export const providerSettingsSchema = z.object({
...openAiNativeSchema.shape,
...mistralSchema.shape,
...deepSeekSchema.shape,
...deepInfraSchema.shape,
...doubaoSchema.shape,
...moonshotSchema.shape,
...unboundSchema.shape,
@@ -438,6 +447,7 @@ export const MODEL_ID_KEYS: Partial<keyof ProviderSettings>[] = [
"huggingFaceModelId",
"ioIntelligenceModelId",
"vercelAiGatewayModelId",
"deepInfraModelId",
]

export const getModelId = (settings: ProviderSettings): string | undefined => {
@@ -559,6 +569,7 @@ export const MODELS_BY_PROVIDER: Record<
openrouter: { id: "openrouter", label: "OpenRouter", models: [] },
requesty: { id: "requesty", label: "Requesty", models: [] },
unbound: { id: "unbound", label: "Unbound", models: [] },
deepinfra: { id: "deepinfra", label: "DeepInfra", models: [] },
"vercel-ai-gateway": { id: "vercel-ai-gateway", label: "Vercel AI Gateway", models: [] },
}

@@ -569,6 +580,7 @@ export const dynamicProviders = [
"openrouter",
"requesty",
"unbound",
"deepinfra",
"vercel-ai-gateway",
] as const satisfies readonly ProviderName[]

14 changes: 14 additions & 0 deletions packages/types/src/providers/deepinfra.ts
@@ -0,0 +1,14 @@
import type { ModelInfo } from "../model.js"

// Default fallback values for DeepInfra when model metadata is not yet loaded.
export const deepInfraDefaultModelId = "Qwen/Qwen3-Coder-480B-A35B-Instruct-Turbo"
Contributor


Consider adding JSDoc comments to document the DeepInfra-specific features, especially the prompt caching support. This would help other developers understand the unique capabilities of this provider.


export const deepInfraDefaultModelInfo: ModelInfo = {
maxTokens: 16384,
contextWindow: 262144,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 0.3,
outputPrice: 1.2,
description: "Qwen 3 Coder 480B A35B Instruct Turbo model, 256K context.",
}
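
In the spirit of the JSDoc comment above, a sketch of how these defaults could be documented. This block is illustrative and not part of the merged diff; the wording of the doc comments is an assumption.

import type { ModelInfo } from "../model.js"

/**
 * Fallback model used before DeepInfra's /models endpoint has been queried.
 * Per-model capabilities such as prompt caching and vision support come from
 * metadata tags returned by the fetcher, so these defaults stay conservative.
 */
export const deepInfraDefaultModelId = "Qwen/Qwen3-Coder-480B-A35B-Instruct-Turbo"

/** Conservative defaults that apply until live model metadata replaces them. */
export const deepInfraDefaultModelInfo: ModelInfo = {
	maxTokens: 16384,
	contextWindow: 262144,
	supportsImages: false,
	supportsPromptCache: false,
	inputPrice: 0.3,
	outputPrice: 1.2,
	description: "Qwen 3 Coder 480B A35B Instruct Turbo model, 256K context.",
}
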
1 change: 1 addition & 0 deletions packages/types/src/providers/index.ts
@@ -29,3 +29,4 @@ export * from "./vscode-llm.js"
export * from "./xai.js"
export * from "./vercel-ai-gateway.js"
export * from "./zai.js"
export * from "./deepinfra.js"
3 changes: 3 additions & 0 deletions src/api/index.ts
@@ -39,6 +39,7 @@ import {
RooHandler,
FeatherlessHandler,
VercelAiGatewayHandler,
DeepInfraHandler,
} from "./providers"
import { NativeOllamaHandler } from "./providers/native-ollama"

@@ -138,6 +139,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
return new XAIHandler(options)
case "groq":
return new GroqHandler(options)
case "deepinfra":
return new DeepInfraHandler(options)
case "huggingface":
return new HuggingFaceHandler(options)
case "chutes":
147 changes: 147 additions & 0 deletions src/api/providers/deepinfra.ts
@@ -0,0 +1,147 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

import { deepInfraDefaultModelId, deepInfraDefaultModelInfo } from "@roo-code/types"

import type { ApiHandlerOptions } from "../../shared/api"
import { calculateApiCostOpenAI } from "../../shared/cost"

import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
import { convertToOpenAiMessages } from "../transform/openai-format"

import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
import { RouterProvider } from "./router-provider"
import { getModelParams } from "../transform/model-params"
import { getModels } from "./fetchers/modelCache"

export class DeepInfraHandler extends RouterProvider implements SingleCompletionHandler {
Contributor


Missing test coverage for this new provider implementation. Could you add unit tests similar to other providers in the codebase? This would help ensure the DeepInfra integration works correctly and prevent regressions.

constructor(options: ApiHandlerOptions) {
super({
options: {
...options,
openAiHeaders: {
"X-Deepinfra-Source": "roo-code",
"X-Deepinfra-Version": `2025-08-25`,
Contributor


Is this intentional? The version date appears to be in the future (August 2025). Should this be '2024-08-25' or another appropriate date?

Suggested change
"X-Deepinfra-Version": `2025-08-25`,
"X-Deepinfra-Version": `2024-08-25`,

Collaborator


We live in the future!

},
},
name: "deepinfra",
baseURL: `${options.deepInfraBaseUrl || "https://api.deepinfra.com/v1/openai"}`,
apiKey: options.deepInfraApiKey || "not-provided",
modelId: options.deepInfraModelId,
defaultModelId: deepInfraDefaultModelId,
defaultModelInfo: deepInfraDefaultModelInfo,
})
}

public override async fetchModel() {
this.models = await getModels({ provider: this.name, apiKey: this.client.apiKey, baseUrl: this.client.baseURL })
return this.getModel()
}

override getModel() {
const id = this.options.deepInfraModelId ?? deepInfraDefaultModelId
const info = this.models[id] ?? deepInfraDefaultModelInfo

const params = getModelParams({
format: "openai",
modelId: id,
model: info,
settings: this.options,
})

return { id, info, ...params }
}

override async *createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
_metadata?: ApiHandlerCreateMessageMetadata,
): ApiStream {
// Ensure we have up-to-date model metadata
await this.fetchModel()
const { id: modelId, info, reasoningEffort: reasoning_effort } = await this.fetchModel()
Contributor


There's a duplicate call to fetchModel() here. Line 61 already calls it, so line 62 is redundant and could cause unnecessary API calls. Consider removing this duplicate.
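
A minimal fix along those lines, in the suggested-change style used elsewhere in this review (not applied in the merged PR): drop the bare await this.fetchModel() and keep only the destructuring call.

Suggested change
await this.fetchModel()
const { id: modelId, info, reasoningEffort: reasoning_effort } = await this.fetchModel()
const { id: modelId, info, reasoningEffort: reasoning_effort } = await this.fetchModel()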

let prompt_cache_key = undefined
if (info.supportsPromptCache && _metadata?.taskId) {
prompt_cache_key = _metadata.taskId
}

const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model: modelId,
messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
stream: true,
stream_options: { include_usage: true },
reasoning_effort,
prompt_cache_key,
} as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming

if (this.supportsTemperature(modelId)) {
requestOptions.temperature = this.options.modelTemperature ?? 0
}

if (this.options.includeMaxTokens === true && info.maxTokens) {
;(requestOptions as any).max_completion_tokens = this.options.modelMaxTokens || info.maxTokens
}

const { data: stream } = await this.client.chat.completions.create(requestOptions).withResponse()

let lastUsage: OpenAI.CompletionUsage | undefined
for await (const chunk of stream) {
const delta = chunk.choices[0]?.delta

if (delta?.content) {
yield { type: "text", text: delta.content }
}

if (delta && "reasoning_content" in delta && delta.reasoning_content) {
yield { type: "reasoning", text: (delta.reasoning_content as string | undefined) || "" }
}

if (chunk.usage) {
lastUsage = chunk.usage
}
}

if (lastUsage) {
yield this.processUsageMetrics(lastUsage, info)
}
}

async completePrompt(prompt: string): Promise<string> {
await this.fetchModel()
const { id: modelId, info } = this.getModel()

const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: modelId,
messages: [{ role: "user", content: prompt }],
}
if (this.supportsTemperature(modelId)) {
requestOptions.temperature = this.options.modelTemperature ?? 0
}
if (this.options.includeMaxTokens === true && info.maxTokens) {
;(requestOptions as any).max_completion_tokens = this.options.modelMaxTokens || info.maxTokens
}

const resp = await this.client.chat.completions.create(requestOptions)
return resp.choices[0]?.message?.content || ""
}

protected processUsageMetrics(usage: any, modelInfo?: any): ApiStreamUsageChunk {
const inputTokens = usage?.prompt_tokens || 0
const outputTokens = usage?.completion_tokens || 0
const cacheWriteTokens = usage?.prompt_tokens_details?.cache_write_tokens || 0
const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0

const totalCost = modelInfo
? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
: 0

return {
type: "usage",
inputTokens,
outputTokens,
cacheWriteTokens: cacheWriteTokens || undefined,
cacheReadTokens: cacheReadTokens || undefined,
totalCost,
}
}
}
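
Picking up the test-coverage comment above, a minimal sketch of a unit test for the handler. It assumes the repository's vitest setup, a test file placed under src/api/providers/__tests__/, and that mocking the shared model cache is acceptable; paths and expectations are illustrative.

import { describe, it, expect, vi } from "vitest"

import { DeepInfraHandler } from "../deepinfra"
import { deepInfraDefaultModelId, deepInfraDefaultModelInfo } from "@roo-code/types"

// Avoid real network calls: the handler resolves model metadata through the shared cache.
vi.mock("../fetchers/modelCache", () => ({
	getModels: vi.fn().mockResolvedValue({}),
}))

describe("DeepInfraHandler", () => {
	it("falls back to the default model when no model id is configured", async () => {
		const handler = new DeepInfraHandler({ deepInfraApiKey: "test-key" })
		const { id, info } = await handler.fetchModel()
		expect(id).toBe(deepInfraDefaultModelId)
		expect(info.contextWindow).toBe(deepInfraDefaultModelInfo.contextWindow)
	})
})
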
71 changes: 71 additions & 0 deletions src/api/providers/fetchers/deepinfra.ts
@@ -0,0 +1,71 @@
import axios from "axios"
import { z } from "zod"

import { type ModelInfo } from "@roo-code/types"

import { DEFAULT_HEADERS } from "../constants"

// DeepInfra models endpoint follows OpenAI /models shape with an added metadata object.

const DeepInfraModelSchema = z.object({
id: z.string(),
object: z.literal("model").optional(),
owned_by: z.string().optional(),
created: z.number().optional(),
root: z.string().optional(),
metadata: z
.object({
description: z.string().optional(),
context_length: z.number().optional(),
max_tokens: z.number().optional(),
tags: z.array(z.string()).optional(), // e.g., ["vision", "prompt_cache"]
pricing: z
.object({
input_tokens: z.number().optional(),
output_tokens: z.number().optional(),
cache_read_tokens: z.number().optional(),
})
.optional(),
})
.optional(),
})

const DeepInfraModelsResponseSchema = z.object({ data: z.array(DeepInfraModelSchema) })

export async function getDeepInfraModels(
apiKey?: string,
baseUrl: string = "https://api.deepinfra.com/v1/openai",
): Promise<Record<string, ModelInfo>> {
const headers: Record<string, string> = { ...DEFAULT_HEADERS }
if (apiKey) headers["Authorization"] = `Bearer ${apiKey}`

const url = `${baseUrl.replace(/\/$/, "")}/models`
const models: Record<string, ModelInfo> = {}

const response = await axios.get(url, { headers })
Contributor


Consider adding more specific error handling here. For example, distinguishing between rate limiting (429), authentication failures (401/403), and other errors would provide better user feedback. Could we enhance this to match the error handling patterns used in other fetchers?

const parsed = DeepInfraModelsResponseSchema.safeParse(response.data)
const data = parsed.success ? parsed.data.data : response.data?.data || []

for (const m of data as Array<z.infer<typeof DeepInfraModelSchema>>) {
const meta = m.metadata || {}
const tags = meta.tags || []

const contextWindow = typeof meta.context_length === "number" ? meta.context_length : 8192
const maxTokens = typeof meta.max_tokens === "number" ? meta.max_tokens : Math.ceil(contextWindow * 0.2)

const info: ModelInfo = {
maxTokens,
contextWindow,
supportsImages: tags.includes("vision"),
supportsPromptCache: tags.includes("prompt_cache"),
inputPrice: meta.pricing?.input_tokens,
outputPrice: meta.pricing?.output_tokens,
cacheReadsPrice: meta.pricing?.cache_read_tokens,
description: meta.description,
}

models[m.id] = info
}

return models
}
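
On the error-handling comment above, one way to distinguish the common failure modes before rethrowing. This is a sketch under stated assumptions: the message wording and how other fetchers surface errors are not taken from the codebase.

import axios from "axios"

// Illustrative wrapper that maps common HTTP failures to clearer messages.
async function fetchDeepInfraModelsResponse(url: string, headers: Record<string, string>) {
	try {
		return await axios.get(url, { headers })
	} catch (error) {
		if (axios.isAxiosError(error)) {
			const status = error.response?.status
			if (status === 401 || status === 403) {
				throw new Error("DeepInfra models request failed: authentication error, check the API key.")
			}
			if (status === 429) {
				throw new Error("DeepInfra models request failed: rate limited, please retry later.")
			}
		}
		throw error
	}
}
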
4 changes: 4 additions & 0 deletions src/api/providers/fetchers/modelCache.ts
@@ -19,6 +19,7 @@ import { GetModelsOptions } from "../../../shared/api"
import { getOllamaModels } from "./ollama"
import { getLMStudioModels } from "./lmstudio"
import { getIOIntelligenceModels } from "./io-intelligence"
import { getDeepInfraModels } from "./deepinfra"
const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })

async function writeModels(router: RouterName, data: ModelRecord) {
@@ -79,6 +80,9 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
case "lmstudio":
models = await getLMStudioModels(options.baseUrl)
break
case "deepinfra":
models = await getDeepInfraModels(options.apiKey, options.baseUrl)
break
case "io-intelligence":
models = await getIOIntelligenceModels(options.apiKey)
break
1 change: 1 addition & 0 deletions src/api/providers/index.ts
@@ -33,3 +33,4 @@ export { FireworksHandler } from "./fireworks"
export { RooHandler } from "./roo"
export { FeatherlessHandler } from "./featherless"
export { VercelAiGatewayHandler } from "./vercel-ai-gateway"
export { DeepInfraHandler } from "./deepinfra"
4 changes: 4 additions & 0 deletions src/core/webview/__tests__/ClineProvider.spec.ts
@@ -2680,6 +2680,7 @@ describe("ClineProvider - Router Models", () => {
expect(mockPostMessage).toHaveBeenCalledWith({
type: "routerModels",
routerModels: {
deepinfra: mockModels,
openrouter: mockModels,
requesty: mockModels,
glama: mockModels,
@@ -2719,6 +2720,7 @@
.mockResolvedValueOnce(mockModels) // glama success
.mockRejectedValueOnce(new Error("Unbound API error")) // unbound fail
.mockResolvedValueOnce(mockModels) // vercel-ai-gateway success
.mockResolvedValueOnce(mockModels) // deepinfra success
.mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm fail

await messageHandler({ type: "requestRouterModels" })
@@ -2727,6 +2729,7 @@
expect(mockPostMessage).toHaveBeenCalledWith({
type: "routerModels",
routerModels: {
deepinfra: mockModels,
openrouter: mockModels,
requesty: {},
glama: mockModels,
@@ -2838,6 +2841,7 @@ describe("ClineProvider - Router Models", () => {
expect(mockPostMessage).toHaveBeenCalledWith({
type: "routerModels",
routerModels: {
deepinfra: mockModels,
openrouter: mockModels,
requesty: mockModels,
glama: mockModels,