Skip to content

Commit 6e52408

Browse files
authored
Merge pull request #2199 from Kilo-Org/feat/deepinfra
feat(provider): add DeepInfra with dynamic model fetching & prompt-caching
2 parents 5d6bcc4 + 26dbac0 commit 6e52408

File tree

26 files changed

+442
-3
lines changed

26 files changed

+442
-3
lines changed

.changeset/pretty-hornets-brake.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"kilo-code": minor
3+
---
4+
5+
Thanks @Thachnh! - Added DeepInfra provider with dynamic model fetching and prompt caching

packages/types/src/global-settings.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ export const SECRET_STATE_KEYS = [
207207
"codeIndexQdrantApiKey",
208208
// kilocode_change start
209209
"kilocodeToken",
210+
"deepInfraApiKey",
210211
// kilocode_change end
211212
"codebaseIndexOpenAiCompatibleApiKey",
212213
"codebaseIndexGeminiApiKey",

packages/types/src/provider-settings.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ export const providerNames = [
5757
"litellm",
5858
// kilocode_change start
5959
"kilocode",
60+
"deepinfra",
6061
"gemini-cli",
6162
"virtual-quota-fallback",
6263
"qwen-code",
@@ -320,6 +321,12 @@ const kilocodeSchema = baseProviderSettingsSchema.extend({
320321
openRouterProviderSort: openRouterProviderSortSchema.optional(),
321322
})
322323

324+
const deepInfraSchema = apiModelIdProviderModelSchema.extend({
325+
deepInfraBaseUrl: z.string().optional(),
326+
deepInfraApiKey: z.string().optional(),
327+
deepInfraModelId: z.string().optional(),
328+
})
329+
323330
export const virtualQuotaFallbackProfileDataSchema = z.object({
324331
profileName: z.string().optional(),
325332
profileId: z.string().optional(),
@@ -393,6 +400,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
393400
fakeAiSchema.merge(z.object({ apiProvider: z.literal("fake-ai") })),
394401
xaiSchema.merge(z.object({ apiProvider: z.literal("xai") })),
395402
// kilocode_change start
403+
deepInfraSchema.merge(z.object({ apiProvider: z.literal("deepinfra") })),
396404
geminiCliSchema.merge(z.object({ apiProvider: z.literal("gemini-cli") })),
397405
kilocodeSchema.merge(z.object({ apiProvider: z.literal("kilocode") })),
398406
virtualQuotaFallbackSchema.merge(z.object({ apiProvider: z.literal("virtual-quota-fallback") })),
@@ -430,6 +438,7 @@ export const providerSettingsSchema = z.object({
430438
...kilocodeSchema.shape,
431439
...virtualQuotaFallbackSchema.shape,
432440
...qwenCodeSchema.shape,
441+
...deepInfraSchema.shape,
433442
// kilocode_change end
434443
...openAiNativeSchema.shape,
435444
...mistralSchema.shape,
@@ -478,6 +487,7 @@ export const MODEL_ID_KEYS: Partial<keyof ProviderSettings>[] = [
478487
"litellmModelId",
479488
"huggingFaceModelId",
480489
"ioIntelligenceModelId",
490+
"deepInfraModelId", // kilocode_change
481491
]
482492

483493
export const getModelId = (settings: ProviderSettings): string | undefined => {
@@ -598,6 +608,7 @@ export const MODELS_BY_PROVIDER: Record<
598608
kilocode: { id: "kilocode", label: "Kilocode", models: [] },
599609
"virtual-quota-fallback": { id: "virtual-quota-fallback", label: "Virtual Quota Fallback", models: [] },
600610
"qwen-code": { id: "qwen-code", label: "Qwen Code", models: [] },
611+
deepinfra: { id: "deepinfra", label: "DeepInfra", models: [] },
601612
// kilocode_change end
602613
}
603614

@@ -611,6 +622,7 @@ export const dynamicProviders = [
611622
// kilocode_change start
612623
"kilocode",
613624
"virtual-quota-fallback",
625+
"deepinfra",
614626
// kilocode_change end
615627
] as const satisfies readonly ProviderName[]
616628

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// kilocode_change: provider added
2+
3+
import type { ModelInfo } from "../model.js"
4+
5+
// Default fallback values for DeepInfra when model metadata is not yet loaded.
6+
export const deepInfraDefaultModelId = "Qwen/Qwen3-Coder-480B-A35B-Instruct-Turbo"
7+
8+
export const deepInfraDefaultModelInfo: ModelInfo = {
9+
maxTokens: 16384,
10+
contextWindow: 262144,
11+
supportsImages: false,
12+
supportsPromptCache: false,
13+
inputPrice: 0.3,
14+
outputPrice: 1.2,
15+
description: "Qwen 3 Coder 480B A35B Instruct Turbo model, 256K context.",
16+
}

packages/types/src/providers/index.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ export * from "./doubao.js"
88
export * from "./featherless.js"
99
export * from "./fireworks.js"
1010
export * from "./gemini.js"
11-
export * from "./gemini-cli.js" // kilocode_change
11+
// kilocode_change start
12+
export * from "./gemini-cli.js"
13+
export * from "./qwen-code.js"
14+
export * from "./deepinfra.js"
15+
// kilocode_change end
1216
export * from "./glama.js"
1317
export * from "./groq.js"
1418
export * from "./huggingface.js"
@@ -21,7 +25,6 @@ export * from "./ollama.js"
2125
export * from "./openai.js"
2226
export * from "./openrouter.js"
2327
export * from "./requesty.js"
24-
export * from "./qwen-code.js" // kilocode_change
2528
export * from "./roo.js"
2629
export * from "./sambanova.js"
2730
export * from "./unbound.js"

src/api/index.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import {
3434
VirtualQuotaFallbackHandler,
3535
GeminiCliHandler,
3636
QwenCodeHandler,
37+
DeepInfraHandler,
3738
// kilocode_change end
3839
ClaudeCodeHandler,
3940
SambaNovaHandler,
@@ -98,6 +99,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
9899
return new VirtualQuotaFallbackHandler(options)
99100
case "qwen-code":
100101
return new QwenCodeHandler(options)
102+
case "deepinfra":
103+
return new DeepInfraHandler(options)
101104
// kilocode_change end
102105
case "anthropic":
103106
return new AnthropicHandler(options)

src/api/providers/deepinfra.ts

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
// kilocode_change - provider added

import { Anthropic } from "@anthropic-ai/sdk" // for message param types
import OpenAI from "openai"

import { deepInfraDefaultModelId, deepInfraDefaultModelInfo } from "@roo-code/types"

import type { ApiHandlerOptions } from "../../shared/api"
import { calculateApiCostOpenAI } from "../../shared/cost"

import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
import { convertToOpenAiMessages } from "../transform/openai-format"

import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
import { RouterProvider } from "./router-provider"
import { getModelParams } from "../transform/model-params"
import { getModels } from "./fetchers/modelCache"

/**
 * DeepInfra provider handler (OpenAI compatible)
 *
 * Streams chat completions through DeepInfra's OpenAI-compatible endpoint.
 * Model metadata is resolved dynamically via the shared model cache
 * (see fetchers/modelCache), with static defaults as a fallback.
 */
export class DeepInfraHandler extends RouterProvider implements SingleCompletionHandler {
	constructor(options: ApiHandlerOptions) {
		super({
			options: {
				...options,
				// Extra headers identify this client to DeepInfra on every request.
				openAiHeaders: {
					"X-Deepinfra-Source": "kilocode",
					"X-Deepinfra-Version": `2025-08-25`,
				},
			},
			name: "deepinfra",
			baseURL: `${options.deepInfraBaseUrl || "https://api.deepinfra.com/v1/openai"}`,
			// Placeholder key when unset — presumably the underlying OpenAI client
			// requires a non-empty string; real auth failures surface server-side.
			apiKey: options.deepInfraApiKey || "not-provided",
			modelId: options.deepInfraModelId,
			defaultModelId: deepInfraDefaultModelId,
			defaultModelInfo: deepInfraDefaultModelInfo,
		})
	}

	/**
	 * Refresh the provider's model catalog from the shared cache, then return
	 * the currently-selected model (id, info, and derived request params).
	 */
	public override async fetchModel() {
		this.models = await getModels({ provider: this.name, apiKey: this.client.apiKey, baseUrl: this.client.baseURL })
		return this.getModel()
	}

	/**
	 * Resolve the configured model id and its metadata. Falls back to the
	 * static DeepInfra defaults when the id is unset or not present in the
	 * (possibly not-yet-fetched) catalog.
	 */
	override getModel() {
		const id = this.options.deepInfraModelId ?? deepInfraDefaultModelId
		const info = this.models[id] ?? deepInfraDefaultModelInfo

		// Derive request parameters (e.g. reasoning effort) from model info + user settings.
		const params = getModelParams({
			format: "openai",
			modelId: id,
			model: info,
			settings: this.options,
		})

		return { id, info, ...params }
	}

	/**
	 * Stream a chat completion.
	 *
	 * Yields text chunks, reasoning chunks (DeepInfra's non-standard
	 * `reasoning_content` delta field), and a final usage/cost chunk when the
	 * server reports usage on the last streamed event.
	 */
	override async *createMessage(
		systemPrompt: string,
		messages: Anthropic.Messages.MessageParam[],
		_metadata?: ApiHandlerCreateMessageMetadata,
	): ApiStream {
		// Always refresh model info first so prompt-cache support and limits are current.
		const { id: modelId, info, reasoningEffort: reasoning_effort } = await this.fetchModel()
		// Use the task id as the prompt-cache key so requests within one task
		// can share DeepInfra's server-side prompt cache.
		let prompt_cache_key = undefined
		if (info.supportsPromptCache && _metadata?.taskId) {
			prompt_cache_key = _metadata.taskId
		}

		const requestOptions = {
			model: modelId,
			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
			stream: true,
			stream_options: { include_usage: true },
			reasoning_effort,
			prompt_cache_key,
		} as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming

		if (this.supportsTemperature(modelId)) {
			requestOptions.temperature = this.options.modelTemperature ?? 0
		}

		// If includeMaxTokens is enabled, set a cap using model info
		if (this.options.includeMaxTokens === true && info.maxTokens) {
			// Prefer modern OpenAI param when available in SDK
			;(requestOptions as any).max_completion_tokens = this.options.modelMaxTokens || info.maxTokens
		}

		const { data: stream } = await this.client.chat.completions.create(requestOptions).withResponse()

		// Usage arrives on the final chunk (stream_options.include_usage above);
		// remember the last one seen and emit it after the stream ends.
		let lastUsage: OpenAI.CompletionUsage | undefined
		for await (const chunk of stream) {
			const delta = chunk.choices[0]?.delta

			if (delta?.content) {
				yield { type: "text", text: delta.content }
			}

			// `reasoning_content` is not in the SDK's delta type, hence the `in` check.
			if (delta && "reasoning_content" in delta && delta.reasoning_content) {
				yield { type: "reasoning", text: (delta.reasoning_content as string | undefined) || "" }
			}

			if (chunk.usage) {
				lastUsage = chunk.usage
			}
		}

		if (lastUsage) {
			yield this.processUsageMetrics(lastUsage, info)
		}
	}

	/**
	 * Non-streaming single-shot completion (SingleCompletionHandler contract).
	 * Returns the first choice's message content, or "" when absent.
	 */
	async completePrompt(prompt: string): Promise<string> {
		const { id: modelId, info } = await this.fetchModel()

		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
			model: modelId,
			messages: [{ role: "user", content: prompt }],
		}
		if (this.supportsTemperature(modelId)) {
			requestOptions.temperature = this.options.modelTemperature ?? 0
		}
		if (this.options.includeMaxTokens === true && info.maxTokens) {
			;(requestOptions as any).max_completion_tokens = this.options.modelMaxTokens || info.maxTokens
		}

		const resp = await this.client.chat.completions.create(requestOptions)
		return resp.choices[0]?.message?.content || ""
	}

	/**
	 * Convert an OpenAI-style usage payload into the internal usage chunk,
	 * including an estimated dollar cost when model pricing info is available.
	 *
	 * NOTE(review): `prompt_tokens_details.cache_write_tokens` is not part of
	 * the standard OpenAI usage shape — presumably a DeepInfra extension;
	 * confirm against their API. `cached_tokens` is the standard read-side field.
	 */
	protected processUsageMetrics(usage: any, modelInfo?: any): ApiStreamUsageChunk {
		const inputTokens = usage?.prompt_tokens || 0
		const outputTokens = usage?.completion_tokens || 0
		const cacheWriteTokens = usage?.prompt_tokens_details?.cache_write_tokens || 0
		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0

		const totalCost = modelInfo
			? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
			: 0

		return {
			type: "usage",
			inputTokens,
			outputTokens,
			// Zero counts are reported as undefined rather than 0.
			cacheWriteTokens: cacheWriteTokens || undefined,
			cacheReadTokens: cacheReadTokens || undefined,
			totalCost,
		}
	}
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import axios from "axios"
2+
import { z } from "zod"
3+
4+
import { type ModelInfo } from "@roo-code/types"
5+
6+
import { DEFAULT_HEADERS } from "../constants"
7+
8+
// DeepInfra models endpoint follows OpenAI /models shape with an added metadata object.
9+
// Use only the supported fields and infer capabilities from tags.
10+
11+
const DeepInfraModelSchema = z.object({
12+
id: z.string(),
13+
object: z.literal("model"),
14+
owned_by: z.string().optional(),
15+
created: z.number().optional(),
16+
root: z.string().optional(),
17+
metadata: z
18+
.object({
19+
description: z.string().optional(),
20+
context_length: z.number().optional(),
21+
max_tokens: z.number().optional(),
22+
tags: z.array(z.string()).optional(), // e.g., ["vision", "prompt_cache"]
23+
pricing: z
24+
.object({
25+
input_tokens: z.number().optional(),
26+
output_tokens: z.number().optional(),
27+
cache_read_tokens: z.number().optional(),
28+
})
29+
.optional(),
30+
})
31+
.optional(),
32+
})
33+
34+
const DeepInfraModelsResponseSchema = z.object({ data: z.array(DeepInfraModelSchema) })
35+
36+
export async function getDeepInfraModels(
37+
apiKey?: string,
38+
baseUrl: string = "https://api.deepinfra.com/v1/openai",
39+
): Promise<Record<string, ModelInfo>> {
40+
const headers: Record<string, string> = { ...DEFAULT_HEADERS }
41+
if (apiKey) headers["Authorization"] = `Bearer ${apiKey}`
42+
43+
const url = `${baseUrl.replace(/\/$/, "")}/models`
44+
const models: Record<string, ModelInfo> = {}
45+
46+
const response = await axios.get(url, { headers })
47+
const parsed = DeepInfraModelsResponseSchema.safeParse(response.data)
48+
const data = parsed.success ? parsed.data.data : response.data?.data || []
49+
50+
for (const m of data as Array<z.infer<typeof DeepInfraModelSchema>>) {
51+
const meta = m.metadata || {}
52+
const tags = meta.tags || []
53+
54+
const contextWindow = typeof meta.context_length === "number" ? meta.context_length : 8192
55+
const maxTokens = typeof meta.max_tokens === "number" ? meta.max_tokens : Math.ceil(contextWindow * 0.2)
56+
57+
const info: ModelInfo = {
58+
maxTokens,
59+
contextWindow,
60+
supportsImages: tags.includes("vision"),
61+
supportsPromptCache: tags.includes("prompt_cache"),
62+
supportsReasoningEffort: tags.includes("reasoning_effort"),
63+
inputPrice: meta.pricing?.input_tokens,
64+
outputPrice: meta.pricing?.output_tokens,
65+
cacheReadsPrice: meta.pricing?.cache_read_tokens,
66+
description: meta.description,
67+
}
68+
69+
models[m.id] = info
70+
}
71+
72+
return models
73+
}

src/api/providers/fetchers/modelCache.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import { getKiloBaseUriFromToken } from "../../../shared/kilocode/token"
1919
import { getOllamaModels } from "./ollama"
2020
import { getLMStudioModels } from "./lmstudio"
2121
import { getIOIntelligenceModels } from "./io-intelligence"
22+
import { getDeepInfraModels } from "./deepinfra" // kilocode_change
2223
const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
2324

2425
export /*kilocode_change*/ async function writeModels(router: RouterName, data: ModelRecord) {
@@ -89,6 +90,9 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
8990
headers: options.kilocodeToken ? { Authorization: `Bearer ${options.kilocodeToken}` } : undefined,
9091
})
9192
break
93+
case "deepinfra":
94+
models = await getDeepInfraModels(options.apiKey, options.baseUrl)
95+
break
9296
case "cerebras":
9397
models = cerebrasModels
9498
break

src/api/providers/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ export { SambaNovaHandler } from "./sambanova"
2626
export { UnboundHandler } from "./unbound"
2727
export { VertexHandler } from "./vertex"
2828
// kilocode_change start
29+
export { DeepInfraHandler } from "./deepinfra"
2930
export { GeminiCliHandler } from "./gemini-cli"
3031
export { QwenCodeHandler } from "./qwen-code"
3132
export { VirtualQuotaFallbackHandler } from "./virtual-quota-fallback"

0 commit comments

Comments
 (0)