Commit 335bd70

feat: add DeepInfra provider with dynamic model fetching and prompt caching

1 parent 0be6743 commit 335bd70
22 files changed, +403 -0 lines changed

packages/types/src/global-settings.ts

Lines changed: 1 addition & 0 deletions
@@ -203,6 +203,7 @@ export const SECRET_STATE_KEYS = [
 	"groqApiKey",
 	"chutesApiKey",
 	"litellmApiKey",
+	"deepInfraApiKey",
 	"codeIndexOpenAiKey",
 	"codeIndexQdrantApiKey",
 	// kilocode_change start

packages/types/src/provider-settings.ts

Lines changed: 12 additions & 0 deletions
@@ -53,6 +53,7 @@ export const providerNames = [
 	"fake-ai",
 	"xai",
 	"groq",
+	"deepinfra",
 	"chutes",
 	"litellm",
 	// kilocode_change start
@@ -248,6 +249,12 @@ const deepSeekSchema = apiModelIdProviderModelSchema.extend({
 	deepSeekApiKey: z.string().optional(),
 })
 
+const deepInfraSchema = apiModelIdProviderModelSchema.extend({
+	deepInfraBaseUrl: z.string().optional(),
+	deepInfraApiKey: z.string().optional(),
+	deepInfraModelId: z.string().optional(),
+})
+
 const doubaoSchema = apiModelIdProviderModelSchema.extend({
 	doubaoBaseUrl: z.string().optional(),
 	doubaoApiKey: z.string().optional(),
@@ -385,6 +392,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
 	openAiNativeSchema.merge(z.object({ apiProvider: z.literal("openai-native") })),
 	mistralSchema.merge(z.object({ apiProvider: z.literal("mistral") })),
 	deepSeekSchema.merge(z.object({ apiProvider: z.literal("deepseek") })),
+	deepInfraSchema.merge(z.object({ apiProvider: z.literal("deepinfra") })),
 	doubaoSchema.merge(z.object({ apiProvider: z.literal("doubao") })),
 	moonshotSchema.merge(z.object({ apiProvider: z.literal("moonshot") })),
 	unboundSchema.merge(z.object({ apiProvider: z.literal("unbound") })),
@@ -434,6 +442,7 @@ export const providerSettingsSchema = z.object({
 	...openAiNativeSchema.shape,
 	...mistralSchema.shape,
 	...deepSeekSchema.shape,
+	...deepInfraSchema.shape,
 	...doubaoSchema.shape,
 	...moonshotSchema.shape,
 	...unboundSchema.shape,
@@ -478,6 +487,7 @@ export const MODEL_ID_KEYS: Partial<keyof ProviderSettings>[] = [
 	"litellmModelId",
 	"huggingFaceModelId",
 	"ioIntelligenceModelId",
+	"deepInfraModelId",
 ]
 
 export const getModelId = (settings: ProviderSettings): string | undefined => {
@@ -593,6 +603,7 @@ export const MODELS_BY_PROVIDER: Record<
 	openrouter: { id: "openrouter", label: "OpenRouter", models: [] },
 	requesty: { id: "requesty", label: "Requesty", models: [] },
 	unbound: { id: "unbound", label: "Unbound", models: [] },
+	deepinfra: { id: "deepinfra", label: "DeepInfra", models: [] },
 
 	// kilocode_change start
 	kilocode: { id: "kilocode", label: "Kilocode", models: [] },
@@ -608,6 +619,7 @@ export const dynamicProviders = [
 	"openrouter",
 	"requesty",
 	"unbound",
+	"deepinfra",
 	// kilocode_change start
 	"kilocode",
 	"virtual-quota-fallback",
packages/types/src/providers/deepinfra.ts

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+import type { ModelInfo } from "../model.js"
+
+// Default fallback values for DeepInfra when model metadata is not yet loaded.
+export const deepInfraDefaultModelId = "Qwen/Qwen3-Coder-480B-A35B-Instruct-Turbo"
+
+export const deepInfraDefaultModelInfo: ModelInfo = {
+	maxTokens: 16384,
+	contextWindow: 262144,
+	supportsImages: false,
+	supportsPromptCache: false,
+	inputPrice: 0.3,
+	outputPrice: 1.2,
+	description: "Qwen 3 Coder 480B A35B Instruct Turbo model, 256K context.",
+}

packages/types/src/providers/index.ts

Lines changed: 1 addition & 0 deletions
@@ -29,3 +29,4 @@ export * from "./vertex.js"
 export * from "./vscode-llm.js"
 export * from "./xai.js"
 export * from "./zai.js"
+export * from "./deepinfra.js"

src/api/index.ts

Lines changed: 3 additions & 0 deletions
@@ -43,6 +43,7 @@ import {
 	FireworksHandler,
 	RooHandler,
 	FeatherlessHandler,
+	DeepInfraHandler,
 } from "./providers"
 // kilocode_change start
 import { KilocodeOpenrouterHandler } from "./providers/kilocode-openrouter"
@@ -145,6 +146,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 			return new XAIHandler(options)
 		case "groq":
 			return new GroqHandler(options)
+		case "deepinfra":
+			return new DeepInfraHandler(options)
 		case "huggingface":
 			return new HuggingFaceHandler(options)
 		case "chutes":

src/api/providers/deepinfra.ts

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
+import { Anthropic } from "@anthropic-ai/sdk" // for message param types
+import OpenAI from "openai"
+
+import { deepInfraDefaultModelId, deepInfraDefaultModelInfo } from "@roo-code/types"
+
+import type { ApiHandlerOptions } from "../../shared/api"
+import { calculateApiCostOpenAI } from "../../shared/cost"
+
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+import { RouterProvider } from "./router-provider"
+import { getModelParams } from "../transform/model-params"
+import { getModels } from "./fetchers/modelCache"
+
+/**
+ * DeepInfra provider handler (OpenAI compatible)
+ */
+export class DeepInfraHandler extends RouterProvider implements SingleCompletionHandler {
+	constructor(options: ApiHandlerOptions) {
+		super({
+			options: {
+				...options,
+				openAiHeaders: {
+					"X-Deepinfra-Source": "kilocode",
+					"X-Deepinfra-Version": `2025-08-25`,
+				},
+			},
+			name: "deepinfra",
+			baseURL: `${options.deepInfraBaseUrl || "https://api.deepinfra.com/v1/openai"}`,
+			apiKey: options.deepInfraApiKey || "not-provided",
+			modelId: options.deepInfraModelId,
+			defaultModelId: deepInfraDefaultModelId,
+			defaultModelInfo: deepInfraDefaultModelInfo,
+		})
+	}
+
+	public override async fetchModel() {
+		this.models = await getModels({ provider: this.name, apiKey: this.client.apiKey, baseUrl: this.client.baseURL })
+		return this.getModel()
+	}
+
+	override getModel() {
+		const id = this.options.deepInfraModelId ?? deepInfraDefaultModelId
+		const info = this.models[id] ?? deepInfraDefaultModelInfo
+
+		const params = getModelParams({
+			format: "openai",
+			modelId: id,
+			model: info,
+			settings: this.options,
+		})
+
+		return { id, info, ...params }
+	}
+
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		_metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		const { id: modelId, info, reasoningEffort: reasoning_effort } = await this.fetchModel()
+
+		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
+			model: modelId,
+			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+			stream: true,
+			stream_options: { include_usage: true },
+			reasoning_effort,
+		}
+
+		if (this.supportsTemperature(modelId)) {
+			requestOptions.temperature = this.options.modelTemperature ?? 0
+		}
+
+		// If includeMaxTokens is enabled, set a cap using model info
+		if (this.options.includeMaxTokens === true && info.maxTokens) {
+			// Prefer modern OpenAI param when available in SDK
+			;(requestOptions as any).max_completion_tokens = this.options.modelMaxTokens || info.maxTokens
+		}
+
+		const { data: stream } = await this.client.chat.completions.create(requestOptions).withResponse()
+
+		let lastUsage: OpenAI.CompletionUsage | undefined
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+
+			if (delta?.content) {
+				yield { type: "text", text: delta.content }
+			}
+
+			if (delta && "reasoning_content" in delta && delta.reasoning_content) {
+				yield { type: "reasoning", text: (delta.reasoning_content as string | undefined) || "" }
+			}
+
+			if (chunk.usage) {
+				lastUsage = chunk.usage
+			}
+		}
+
+		if (lastUsage) {
+			yield this.processUsageMetrics(lastUsage, info)
+		}
+	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		const { id: modelId, info } = await this.fetchModel()
+
+		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+			model: modelId,
+			messages: [{ role: "user", content: prompt }],
+		}
+		if (this.supportsTemperature(modelId)) {
+			requestOptions.temperature = this.options.modelTemperature ?? 0
+		}
+		if (this.options.includeMaxTokens === true && info.maxTokens) {
+			;(requestOptions as any).max_completion_tokens = this.options.modelMaxTokens || info.maxTokens
+		}
+
+		const resp = await this.client.chat.completions.create(requestOptions)
+		return resp.choices[0]?.message?.content || ""
+	}
+
+	protected processUsageMetrics(usage: any, modelInfo?: any): ApiStreamUsageChunk {
+		const inputTokens = usage?.prompt_tokens || 0
+		const outputTokens = usage?.completion_tokens || 0
+		const cacheWriteTokens = usage?.prompt_tokens_details?.cache_write_tokens || 0
+		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
+
+		const totalCost = modelInfo
+			? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
+			: 0
+
+		return {
+			type: "usage",
+			inputTokens,
+			outputTokens,
+			cacheWriteTokens: cacheWriteTokens || undefined,
+			cacheReadTokens: cacheReadTokens || undefined,
+			totalCost,
+		}
+	}
+}
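Continuing the buildApiHandler sketch above, a minimal consumer of the streaming path. Chunk shapes follow the ApiStream yields in createMessage; the prompts are illustrative.

const stream = handler.createMessage("You are a helpful assistant.", [
	{ role: "user", content: "Hello" },
])

for await (const chunk of stream) {
	if (chunk.type === "text") process.stdout.write(chunk.text)
	if (chunk.type === "reasoning") process.stdout.write(chunk.text) // reasoning_content passthrough
	if (chunk.type === "usage") console.log("\ncost:", chunk.totalCost) // emitted once, from the final usage chunk
}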
src/api/providers/fetchers/deepinfra.ts

Lines changed: 73 additions & 0 deletions

@@ -0,0 +1,73 @@
+import axios from "axios"
+import { z } from "zod"
+
+import { type ModelInfo } from "@roo-code/types"
+
+import { DEFAULT_HEADERS } from "../constants"
+
+// DeepInfra models endpoint follows OpenAI /models shape with an added metadata object.
+// Use only the supported fields and infer capabilities from tags.
+
+const DeepInfraModelSchema = z.object({
+	id: z.string(),
+	object: z.literal("model"),
+	owned_by: z.string().optional(),
+	created: z.number().optional(),
+	root: z.string().optional(),
+	metadata: z
+		.object({
+			description: z.string().optional(),
+			context_length: z.number().optional(),
+			max_tokens: z.number().optional(),
+			tags: z.array(z.string()).optional(), // e.g., ["vision", "prompt_cache"]
+			pricing: z
+				.object({
+					input_tokens: z.number().optional(),
+					output_tokens: z.number().optional(),
+					cache_read_tokens: z.number().optional(),
+				})
+				.optional(),
+		})
+		.optional(),
+})
+
+const DeepInfraModelsResponseSchema = z.object({ data: z.array(DeepInfraModelSchema) })
+
+export async function getDeepInfraModels(
+	apiKey?: string,
+	baseUrl: string = "https://api.deepinfra.com/v1/openai",
+): Promise<Record<string, ModelInfo>> {
+	const headers: Record<string, string> = { ...DEFAULT_HEADERS }
+	if (apiKey) headers["Authorization"] = `Bearer ${apiKey}`
+
+	const url = `${baseUrl.replace(/\/$/, "")}/models`
+	const models: Record<string, ModelInfo> = {}
+
+	const response = await axios.get(url, { headers })
+	const parsed = DeepInfraModelsResponseSchema.safeParse(response.data)
+	const data = parsed.success ? parsed.data.data : response.data?.data || []
+
+	for (const m of data as Array<z.infer<typeof DeepInfraModelSchema>>) {
+		const meta = m.metadata || {}
+		const tags = meta.tags || []
+
+		const contextWindow = typeof meta.context_length === "number" ? meta.context_length : 8192
+		const maxTokens = typeof meta.max_tokens === "number" ? meta.max_tokens : Math.ceil(contextWindow * 0.2)
+
+		const info: ModelInfo = {
+			maxTokens,
+			contextWindow,
+			supportsImages: tags.includes("vision"),
+			supportsPromptCache: tags.includes("prompt_cache"),
+			supportsReasoningEffort: tags.includes("reasoning_effort"),
+			inputPrice: meta.pricing?.input_tokens,
+			outputPrice: meta.pricing?.output_tokens,
+			cacheReadsPrice: meta.pricing?.cache_read_tokens,
+			description: meta.description,
+		}
+
+		models[m.id] = info
+	}
+
+	return models
+}
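To make the tag-driven mapping concrete, here is a hypothetical /models entry and the ModelInfo it would yield under the rules above (when metadata.max_tokens is absent, maxTokens falls back to 20% of the context window):

// Hypothetical entry, shaped per DeepInfraModelSchema:
const entry = {
	id: "example/model",
	object: "model" as const,
	metadata: {
		context_length: 131072,
		tags: ["vision", "prompt_cache"],
		pricing: { input_tokens: 0.3, output_tokens: 1.2 },
	},
}

// Resulting ModelInfo for "example/model":
// maxTokens: Math.ceil(131072 * 0.2) === 26215
// contextWindow: 131072
// supportsImages: true ("vision" tag), supportsPromptCache: true ("prompt_cache" tag)
// supportsReasoningEffort: false (no "reasoning_effort" tag)
// inputPrice: 0.3, outputPrice: 1.2, cacheReadsPrice: undefined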

src/api/providers/fetchers/modelCache.ts

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,7 @@ import { getKiloBaseUriFromToken } from "../../../shared/kilocode/token"
 import { getOllamaModels } from "./ollama"
 import { getLMStudioModels } from "./lmstudio"
 import { getIOIntelligenceModels } from "./io-intelligence"
+import { getDeepInfraModels } from "./deepinfra"
 const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
 
 export /*kilocode_change*/ async function writeModels(router: RouterName, data: ModelRecord) {
@@ -78,6 +79,9 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
 			// Type safety ensures apiKey and baseUrl are always provided for litellm
 			models = await getLiteLLMModels(options.apiKey, options.baseUrl)
 			break
+		case "deepinfra":
+			models = await getDeepInfraModels(options.apiKey, options.baseUrl)
+			break
 		// kilocode_change start
 		case "kilocode-openrouter":
 			models = await getOpenRouterModels({
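Note that getModels results sit behind the NodeCache declared above (stdTTL of 5 * 60 seconds), so repeated lookups within five minutes reuse the fetched DeepInfra list rather than re-hitting the endpoint. A call sketch, with option values assumed:

const models = await getModels({
	provider: "deepinfra",
	apiKey: process.env.DEEPINFRA_API_KEY,
	baseUrl: "https://api.deepinfra.com/v1/openai", // optional; the fetcher defaults to this URL
})
console.log(Object.keys(models).length, "DeepInfra models (cached for ~5 minutes)")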

src/api/providers/index.ts

Lines changed: 1 addition & 0 deletions
@@ -36,3 +36,4 @@ export { ZAiHandler } from "./zai"
 export { FireworksHandler } from "./fireworks"
 export { RooHandler } from "./roo"
 export { FeatherlessHandler } from "./featherless"
+export { DeepInfraHandler } from "./deepinfra"

src/core/webview/webviewMessageHandler.ts

Lines changed: 2 additions & 0 deletions
@@ -571,6 +571,7 @@ export const webviewMessageHandler = async (
 		"kilocode-openrouter": {}, // kilocode_change
 		ollama: {},
 		lmstudio: {},
+		deepinfra: {},
 	}
 
 	const safeGetModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
@@ -613,6 +614,7 @@ export const webviewMessageHandler = async (
 			},
 		},
 		{ key: "ollama", options: { provider: "ollama", baseUrl: apiConfiguration.ollamaBaseUrl } },
+		{ key: "deepinfra", options: { provider: "deepinfra", apiKey: apiConfiguration.deepInfraApiKey } },
 	]
 	// kilocode_change end
