
Commit a670936

Thachnhmtone authored and committed
feat: Add DeepInfra as a model provider in Roo Code (RooCodeInc#7677)
1 parent 881516a commit a670936

File tree

23 files changed · +422 -0 lines changed


.changeset/petite-rats-admire.md

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+---
+"roo-cline": minor
+"@roo-code/types": patch
+---
+
+Added DeepInfra provider with dynamic model fetching and prompt caching

packages/types/src/global-settings.ts

Lines changed: 1 addition & 0 deletions
@@ -192,6 +192,7 @@ export const SECRET_STATE_KEYS = [
 	"groqApiKey",
 	"chutesApiKey",
 	"litellmApiKey",
+	"deepInfraApiKey",
 	"codeIndexOpenAiKey",
 	"codeIndexQdrantApiKey",
 	"codebaseIndexOpenAiCompatibleApiKey",

packages/types/src/provider-settings.ts

Lines changed: 12 additions & 0 deletions
@@ -48,6 +48,7 @@ export const providerNames = [
 	"mistral",
 	"moonshot",
 	"deepseek",
+	"deepinfra",
 	"doubao",
 	"qwen-code",
 	"unbound",
@@ -236,6 +237,12 @@ const deepSeekSchema = apiModelIdProviderModelSchema.extend({
 	deepSeekApiKey: z.string().optional(),
 })

+const deepInfraSchema = apiModelIdProviderModelSchema.extend({
+	deepInfraBaseUrl: z.string().optional(),
+	deepInfraApiKey: z.string().optional(),
+	deepInfraModelId: z.string().optional(),
+})
+
 const doubaoSchema = apiModelIdProviderModelSchema.extend({
 	doubaoBaseUrl: z.string().optional(),
 	doubaoApiKey: z.string().optional(),
@@ -349,6 +356,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
 	openAiNativeSchema.merge(z.object({ apiProvider: z.literal("openai-native") })),
 	mistralSchema.merge(z.object({ apiProvider: z.literal("mistral") })),
 	deepSeekSchema.merge(z.object({ apiProvider: z.literal("deepseek") })),
+	deepInfraSchema.merge(z.object({ apiProvider: z.literal("deepinfra") })),
 	doubaoSchema.merge(z.object({ apiProvider: z.literal("doubao") })),
 	moonshotSchema.merge(z.object({ apiProvider: z.literal("moonshot") })),
 	unboundSchema.merge(z.object({ apiProvider: z.literal("unbound") })),
@@ -389,6 +397,7 @@ export const providerSettingsSchema = z.object({
 	...openAiNativeSchema.shape,
 	...mistralSchema.shape,
 	...deepSeekSchema.shape,
+	...deepInfraSchema.shape,
 	...doubaoSchema.shape,
 	...moonshotSchema.shape,
 	...unboundSchema.shape,
@@ -438,6 +447,7 @@ export const MODEL_ID_KEYS: Partial<keyof ProviderSettings>[] = [
 	"huggingFaceModelId",
 	"ioIntelligenceModelId",
 	"vercelAiGatewayModelId",
+	"deepInfraModelId",
 ]

 export const getModelId = (settings: ProviderSettings): string | undefined => {
@@ -559,6 +569,7 @@ export const MODELS_BY_PROVIDER: Record<
 	openrouter: { id: "openrouter", label: "OpenRouter", models: [] },
 	requesty: { id: "requesty", label: "Requesty", models: [] },
 	unbound: { id: "unbound", label: "Unbound", models: [] },
+	deepinfra: { id: "deepinfra", label: "DeepInfra", models: [] },
 	"vercel-ai-gateway": { id: "vercel-ai-gateway", label: "Vercel AI Gateway", models: [] },
 }
@@ -569,6 +580,7 @@ export const dynamicProviders = [
 	"openrouter",
 	"requesty",
 	"unbound",
+	"deepinfra",
 	"vercel-ai-gateway",
 ] as const satisfies readonly ProviderName[]
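For context, a minimal sketch of how these additions are exercised: a DeepInfra settings object selected by the discriminated union. The import path assumes the schema is re-exported from the package root, and the key and model id values are placeholders.

import { providerSettingsSchemaDiscriminated } from "@roo-code/types"

// "deepinfra" now routes the union to deepInfraSchema; the provider-specific
// fields are optional strings, so partial configs still validate.
const result = providerSettingsSchemaDiscriminated.safeParse({
	apiProvider: "deepinfra",
	deepInfraApiKey: "di-xxxxxxxx", // placeholder
	deepInfraModelId: "Qwen/Qwen3-Coder-480B-A35B-Instruct-Turbo",
})

console.log(result.success) // true when the shape matches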

packages/types/src/providers/deepinfra.ts

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+import type { ModelInfo } from "../model.js"
+
+// Default fallback values for DeepInfra when model metadata is not yet loaded.
+export const deepInfraDefaultModelId = "Qwen/Qwen3-Coder-480B-A35B-Instruct-Turbo"
+
+export const deepInfraDefaultModelInfo: ModelInfo = {
+	maxTokens: 16384,
+	contextWindow: 262144,
+	supportsImages: false,
+	supportsPromptCache: false,
+	inputPrice: 0.3,
+	outputPrice: 1.2,
+	description: "Qwen 3 Coder 480B A35B Instruct Turbo model, 256K context.",
+}
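A small illustrative sketch of the intended fallback behaviour (the resolve helper below is hypothetical, not part of the commit): callers use these constants until the dynamically fetched model list arrives. The handler's getModel() further down follows the same pattern.

import { deepInfraDefaultModelId, deepInfraDefaultModelInfo, type ModelInfo } from "@roo-code/types"

// Hypothetical helper: prefer fetched metadata, otherwise fall back to the defaults above.
function resolveDeepInfraModel(models: Record<string, ModelInfo>, requestedId?: string) {
	const id = requestedId ?? deepInfraDefaultModelId
	return { id, info: models[id] ?? deepInfraDefaultModelInfo }
}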

packages/types/src/providers/index.ts

Lines changed: 1 addition & 0 deletions
@@ -29,3 +29,4 @@ export * from "./vscode-llm.js"
 export * from "./xai.js"
 export * from "./vercel-ai-gateway.js"
 export * from "./zai.js"
+export * from "./deepinfra.js"

src/api/index.ts

Lines changed: 3 additions & 0 deletions
@@ -39,6 +39,7 @@ import {
 	RooHandler,
 	FeatherlessHandler,
 	VercelAiGatewayHandler,
+	DeepInfraHandler,
 } from "./providers"
 import { NativeOllamaHandler } from "./providers/native-ollama"

@@ -138,6 +139,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 			return new XAIHandler(options)
 		case "groq":
 			return new GroqHandler(options)
+		case "deepinfra":
+			return new DeepInfraHandler(options)
 		case "huggingface":
 			return new HuggingFaceHandler(options)
 		case "chutes":

src/api/providers/deepinfra.ts

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
+
+import { deepInfraDefaultModelId, deepInfraDefaultModelInfo } from "@roo-code/types"
+
+import type { ApiHandlerOptions } from "../../shared/api"
+import { calculateApiCostOpenAI } from "../../shared/cost"
+
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+import { RouterProvider } from "./router-provider"
+import { getModelParams } from "../transform/model-params"
+import { getModels } from "./fetchers/modelCache"
+
+export class DeepInfraHandler extends RouterProvider implements SingleCompletionHandler {
+	constructor(options: ApiHandlerOptions) {
+		super({
+			options: {
+				...options,
+				openAiHeaders: {
+					"X-Deepinfra-Source": "roo-code",
+					"X-Deepinfra-Version": `2025-08-25`,
+				},
+			},
+			name: "deepinfra",
+			baseURL: `${options.deepInfraBaseUrl || "https://api.deepinfra.com/v1/openai"}`,
+			apiKey: options.deepInfraApiKey || "not-provided",
+			modelId: options.deepInfraModelId,
+			defaultModelId: deepInfraDefaultModelId,
+			defaultModelInfo: deepInfraDefaultModelInfo,
+		})
+	}
+
+	public override async fetchModel() {
+		this.models = await getModels({ provider: this.name, apiKey: this.client.apiKey, baseUrl: this.client.baseURL })
+		return this.getModel()
+	}
+
+	override getModel() {
+		const id = this.options.deepInfraModelId ?? deepInfraDefaultModelId
+		const info = this.models[id] ?? deepInfraDefaultModelInfo
+
+		const params = getModelParams({
+			format: "openai",
+			modelId: id,
+			model: info,
+			settings: this.options,
+		})
+
+		return { id, info, ...params }
+	}
+
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		_metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		// Ensure we have up-to-date model metadata
+		await this.fetchModel()
+		const { id: modelId, info, reasoningEffort: reasoning_effort } = await this.fetchModel()
+		let prompt_cache_key = undefined
+		if (info.supportsPromptCache && _metadata?.taskId) {
+			prompt_cache_key = _metadata.taskId
+		}
+
+		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
+			model: modelId,
+			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+			stream: true,
+			stream_options: { include_usage: true },
+			reasoning_effort,
+			prompt_cache_key,
+		} as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming
+
+		if (this.supportsTemperature(modelId)) {
+			requestOptions.temperature = this.options.modelTemperature ?? 0
+		}
+
+		if (this.options.includeMaxTokens === true && info.maxTokens) {
+			;(requestOptions as any).max_completion_tokens = this.options.modelMaxTokens || info.maxTokens
+		}
+
+		const { data: stream } = await this.client.chat.completions.create(requestOptions).withResponse()
+
+		let lastUsage: OpenAI.CompletionUsage | undefined
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+
+			if (delta?.content) {
+				yield { type: "text", text: delta.content }
+			}
+
+			if (delta && "reasoning_content" in delta && delta.reasoning_content) {
+				yield { type: "reasoning", text: (delta.reasoning_content as string | undefined) || "" }
+			}
+
+			if (chunk.usage) {
+				lastUsage = chunk.usage
+			}
+		}
+
+		if (lastUsage) {
+			yield this.processUsageMetrics(lastUsage, info)
+		}
+	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		await this.fetchModel()
+		const { id: modelId, info } = this.getModel()
+
+		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+			model: modelId,
+			messages: [{ role: "user", content: prompt }],
+		}
+		if (this.supportsTemperature(modelId)) {
+			requestOptions.temperature = this.options.modelTemperature ?? 0
+		}
+		if (this.options.includeMaxTokens === true && info.maxTokens) {
+			;(requestOptions as any).max_completion_tokens = this.options.modelMaxTokens || info.maxTokens
+		}
+
+		const resp = await this.client.chat.completions.create(requestOptions)
+		return resp.choices[0]?.message?.content || ""
+	}
+
+	protected processUsageMetrics(usage: any, modelInfo?: any): ApiStreamUsageChunk {
+		const inputTokens = usage?.prompt_tokens || 0
+		const outputTokens = usage?.completion_tokens || 0
+		const cacheWriteTokens = usage?.prompt_tokens_details?.cache_write_tokens || 0
+		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
+
+		const totalCost = modelInfo
+			? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
+			: 0
+
+		return {
+			type: "usage",
+			inputTokens,
+			outputTokens,
+			cacheWriteTokens: cacheWriteTokens || undefined,
+			cacheReadTokens: cacheReadTokens || undefined,
+			totalCost,
+		}
+	}
+}
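To show how the streamed chunks are consumed, a rough sketch that drives createMessage directly; it assumes the chunk shapes yielded above ("text", "reasoning", "usage") and keeps the options cast loose for brevity.

import { DeepInfraHandler } from "./deepinfra" // relative to src/api/providers

async function demo() {
	const handler = new DeepInfraHandler({
		deepInfraApiKey: process.env.DEEPINFRA_API_KEY,
	} as any) // ApiHandlerOptions; cast kept loose for this sketch

	// createMessage refreshes model metadata, then streams text/reasoning/usage chunks.
	for await (const chunk of handler.createMessage("You are a concise assistant.", [
		{ role: "user", content: "Say hello." },
	])) {
		if (chunk.type === "text") process.stdout.write(chunk.text)
		if (chunk.type === "usage") console.log("\ncost:", chunk.totalCost)
	}
}

demo()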
src/api/providers/fetchers/deepinfra.ts

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+import axios from "axios"
+import { z } from "zod"
+
+import { type ModelInfo } from "@roo-code/types"
+
+import { DEFAULT_HEADERS } from "../constants"
+
+// DeepInfra models endpoint follows OpenAI /models shape with an added metadata object.
+
+const DeepInfraModelSchema = z.object({
+	id: z.string(),
+	object: z.literal("model").optional(),
+	owned_by: z.string().optional(),
+	created: z.number().optional(),
+	root: z.string().optional(),
+	metadata: z
+		.object({
+			description: z.string().optional(),
+			context_length: z.number().optional(),
+			max_tokens: z.number().optional(),
+			tags: z.array(z.string()).optional(), // e.g., ["vision", "prompt_cache"]
+			pricing: z
+				.object({
+					input_tokens: z.number().optional(),
+					output_tokens: z.number().optional(),
+					cache_read_tokens: z.number().optional(),
+				})
+				.optional(),
+		})
+		.optional(),
+})
+
+const DeepInfraModelsResponseSchema = z.object({ data: z.array(DeepInfraModelSchema) })
+
+export async function getDeepInfraModels(
+	apiKey?: string,
+	baseUrl: string = "https://api.deepinfra.com/v1/openai",
+): Promise<Record<string, ModelInfo>> {
+	const headers: Record<string, string> = { ...DEFAULT_HEADERS }
+	if (apiKey) headers["Authorization"] = `Bearer ${apiKey}`
+
+	const url = `${baseUrl.replace(/\/$/, "")}/models`
+	const models: Record<string, ModelInfo> = {}
+
+	const response = await axios.get(url, { headers })
+	const parsed = DeepInfraModelsResponseSchema.safeParse(response.data)
+	const data = parsed.success ? parsed.data.data : response.data?.data || []
+
+	for (const m of data as Array<z.infer<typeof DeepInfraModelSchema>>) {
+		const meta = m.metadata || {}
+		const tags = meta.tags || []
+
+		const contextWindow = typeof meta.context_length === "number" ? meta.context_length : 8192
+		const maxTokens = typeof meta.max_tokens === "number" ? meta.max_tokens : Math.ceil(contextWindow * 0.2)
+
+		const info: ModelInfo = {
+			maxTokens,
+			contextWindow,
+			supportsImages: tags.includes("vision"),
+			supportsPromptCache: tags.includes("prompt_cache"),
+			inputPrice: meta.pricing?.input_tokens,
+			outputPrice: meta.pricing?.output_tokens,
+			cacheReadsPrice: meta.pricing?.cache_read_tokens,
+			description: meta.description,
+		}
+
+		models[m.id] = info
+	}
+
+	return models
+}
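As a worked example of this mapping, a hypothetical /models entry and the ModelInfo the loop above would derive from it (all numeric values are invented for illustration; only the field names come from the schema in this file):

import type { ModelInfo } from "@roo-code/types"

// Hypothetical entry from GET {baseUrl}/models:
const sampleEntry = {
	id: "meta-llama/Llama-3.3-70B-Instruct",
	object: "model",
	metadata: {
		context_length: 131072,
		max_tokens: 8192,
		tags: ["prompt_cache"],
		pricing: { input_tokens: 0.23, output_tokens: 0.4, cache_read_tokens: 0.06 },
	},
}

// What getDeepInfraModels would record under sampleEntry.id:
const derived: ModelInfo = {
	maxTokens: 8192, // metadata.max_tokens
	contextWindow: 131072, // metadata.context_length
	supportsImages: false, // no "vision" tag
	supportsPromptCache: true, // "prompt_cache" tag present
	inputPrice: 0.23,
	outputPrice: 0.4,
	cacheReadsPrice: 0.06,
	description: undefined,
}

When metadata is missing entirely, the fallbacks above apply: contextWindow 8192 and maxTokens of ceil(8192 * 0.2) = 1639.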

src/api/providers/fetchers/modelCache.ts

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,7 @@ import { GetModelsOptions } from "../../../shared/api"
 import { getOllamaModels } from "./ollama"
 import { getLMStudioModels } from "./lmstudio"
 import { getIOIntelligenceModels } from "./io-intelligence"
+import { getDeepInfraModels } from "./deepinfra"
 const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })

 async function writeModels(router: RouterName, data: ModelRecord) {
@@ -79,6 +80,9 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
 		case "lmstudio":
 			models = await getLMStudioModels(options.baseUrl)
 			break
+		case "deepinfra":
+			models = await getDeepInfraModels(options.apiKey, options.baseUrl)
+			break
 		case "io-intelligence":
 			models = await getIOIntelligenceModels(options.apiKey)
 			break
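A brief sketch of the call the handler makes through this cache layer; the provider, apiKey, and baseUrl fields follow the case added above, and the reuse window comes from the NodeCache stdTTL at the top of this file.

import { getModels } from "./modelCache" // relative to src/api/providers/fetchers

async function listDeepInfraModels() {
	// First call reaches the DeepInfra /models endpoint via getDeepInfraModels;
	// subsequent calls within the 5-minute TTL are expected to be served from memory.
	const models = await getModels({
		provider: "deepinfra",
		apiKey: process.env.DEEPINFRA_API_KEY,
		baseUrl: "https://api.deepinfra.com/v1/openai",
	})
	console.log(`${Object.keys(models).length} DeepInfra models available`)
}

listDeepInfraModels()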

src/api/providers/index.ts

Lines changed: 1 addition & 0 deletions
@@ -33,3 +33,4 @@ export { FireworksHandler } from "./fireworks"
 export { RooHandler } from "./roo"
 export { FeatherlessHandler } from "./featherless"
 export { VercelAiGatewayHandler } from "./vercel-ai-gateway"
+export { DeepInfraHandler } from "./deepinfra"
