Commit 3c24584
feat: add DeepInfra as a model provider
- Add DeepInfra types and model definitions
- Implement DeepInfra provider handler with OpenAI-compatible API
- Add dynamic model fetching from DeepInfra API
- Support prompt caching for reduced costs
- Update UI components to support DeepInfra selection
- Add DeepInfra to router configuration

Implements #7661
1 parent 966ed76 commit 3c24584

File tree

11 files changed: +295 −1 lines changed

packages/types/src/provider-settings.ts

Lines changed: 11 additions & 0 deletions

@@ -34,6 +34,7 @@ import {
 export const providerNames = [
 	"anthropic",
 	"claude-code",
+	"deepinfra",
 	"glama",
 	"openrouter",
 	"bedrock",
@@ -294,6 +295,11 @@ const cerebrasSchema = apiModelIdProviderModelSchema.extend({
 	cerebrasApiKey: z.string().optional(),
 })

+const deepInfraSchema = baseProviderSettingsSchema.extend({
+	deepInfraApiKey: z.string().optional(),
+	deepInfraModelId: z.string().optional(),
+})
+
 const sambaNovaSchema = apiModelIdProviderModelSchema.extend({
 	sambaNovaApiKey: z.string().optional(),
 })
@@ -336,6 +342,7 @@ const defaultSchema = z.object({
 export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProvider", [
 	anthropicSchema.merge(z.object({ apiProvider: z.literal("anthropic") })),
 	claudeCodeSchema.merge(z.object({ apiProvider: z.literal("claude-code") })),
+	deepInfraSchema.merge(z.object({ apiProvider: z.literal("deepinfra") })),
 	glamaSchema.merge(z.object({ apiProvider: z.literal("glama") })),
 	openRouterSchema.merge(z.object({ apiProvider: z.literal("openrouter") })),
 	bedrockSchema.merge(z.object({ apiProvider: z.literal("bedrock") })),
@@ -376,6 +383,7 @@ export const providerSettingsSchema = z.object({
 	apiProvider: providerNamesSchema.optional(),
 	...anthropicSchema.shape,
 	...claudeCodeSchema.shape,
+	...deepInfraSchema.shape,
 	...glamaSchema.shape,
 	...openRouterSchema.shape,
 	...bedrockSchema.shape,
@@ -426,6 +434,7 @@ export const PROVIDER_SETTINGS_KEYS = providerSettingsSchema.keyof().options

 export const MODEL_ID_KEYS: Partial<keyof ProviderSettings>[] = [
 	"apiModelId",
+	"deepInfraModelId",
 	"glamaModelId",
 	"openRouterModelId",
 	"openAiModelId",
@@ -489,6 +498,7 @@ export const MODELS_BY_PROVIDER: Record<
 		label: "Chutes AI",
 		models: Object.keys(chutesModels),
 	},
+	deepinfra: { id: "deepinfra", label: "DeepInfra", models: [] },
 	"claude-code": { id: "claude-code", label: "Claude Code", models: Object.keys(claudeCodeModels) },
 	deepseek: {
 		id: "deepseek",
@@ -563,6 +573,7 @@ export const MODELS_BY_PROVIDER: Record<
 }

 export const dynamicProviders = [
+	"deepinfra",
 	"glama",
 	"huggingface",
 	"litellm",

packages/types/src/providers/deepinfra.ts

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+import type { ModelInfo } from "../model.js"
+
+// DeepInfra models are fetched dynamically from their API
+// This type represents the model IDs that will be available
+export type DeepInfraModelId = string
+
+// Default model to use when none is specified
+export const deepInfraDefaultModelId: DeepInfraModelId = "meta-llama/Llama-3.3-70B-Instruct"
+
+// DeepInfra models will be fetched dynamically, so we provide an empty object
+// The actual models will be populated at runtime via the API
+export const deepInfraModels = {} as const satisfies Record<string, ModelInfo>

packages/types/src/providers/index.ts

Lines changed: 1 addition & 0 deletions

@@ -3,6 +3,7 @@ export * from "./bedrock.js"
 export * from "./cerebras.js"
 export * from "./chutes.js"
 export * from "./claude-code.js"
+export * from "./deepinfra.js"
 export * from "./deepseek.js"
 export * from "./doubao.js"
 export * from "./featherless.js"

src/api/index.ts

Lines changed: 3 additions & 0 deletions

@@ -9,6 +9,7 @@ import {
 	AnthropicHandler,
 	AwsBedrockHandler,
 	CerebrasHandler,
+	DeepInfraHandler,
 	OpenRouterHandler,
 	VertexHandler,
 	AnthropicVertexHandler,
@@ -114,6 +115,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 			return new GeminiHandler(options)
 		case "openai-native":
 			return new OpenAiNativeHandler(options)
+		case "deepinfra":
+			return new DeepInfraHandler(options)
 		case "deepseek":
 			return new DeepSeekHandler(options)
 		case "doubao":

src/api/providers/deepinfra.ts

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
+import { type DeepInfraModelId, deepInfraDefaultModelId } from "@roo-code/types"
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
+
+import type { ApiHandlerOptions } from "../../shared/api"
+import type { ApiHandlerCreateMessageMetadata } from "../index"
+import type { ModelInfo } from "@roo-code/types"
+import { ApiStream } from "../transform/stream"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { calculateApiCostOpenAI } from "../../shared/cost"
+
+import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
+
+// Enhanced usage interface to support DeepInfra's cached token fields
+interface DeepInfraUsage extends OpenAI.CompletionUsage {
+	prompt_tokens_details?: {
+		cached_tokens?: number
+	}
+}
+
+export class DeepInfraHandler extends BaseOpenAiCompatibleProvider<DeepInfraModelId> {
+	constructor(options: ApiHandlerOptions) {
+		// Initialize with empty models, will be populated dynamically
+		super({
+			...options,
+			providerName: "DeepInfra",
+			baseURL: "https://api.deepinfra.com/v1/openai",
+			apiKey: options.deepInfraApiKey,
+			defaultProviderModelId: deepInfraDefaultModelId,
+			providerModels: {},
+			defaultTemperature: 0.7,
+		})
+	}
+
+	override getModel() {
+		const modelId = this.options.deepInfraModelId || deepInfraDefaultModelId
+
+		// For DeepInfra, we use a default model configuration
+		// The actual model info will be fetched dynamically via the fetcher
+		const defaultModelInfo: ModelInfo = {
+			maxTokens: 4096,
+			contextWindow: 32768,
+			supportsImages: false,
+			supportsPromptCache: true,
+			inputPrice: 0.15,
+			outputPrice: 0.6,
+			cacheReadsPrice: 0.075, // 50% discount for cached tokens
+			description: "DeepInfra model",
+		}
+
+		return { id: modelId, info: defaultModelInfo }
+	}
+
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		const stream = await this.createStream(systemPrompt, messages, metadata)
+
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+
+			if (delta?.content) {
+				yield {
+					type: "text",
+					text: delta.content,
+				}
+			}
+
+			if (chunk.usage) {
+				yield* this.yieldUsage(chunk.usage as DeepInfraUsage)
+			}
+		}
+	}
+
+	private async *yieldUsage(usage: DeepInfraUsage | undefined): ApiStream {
+		const { info } = this.getModel()
+		const inputTokens = usage?.prompt_tokens || 0
+		const outputTokens = usage?.completion_tokens || 0
+
+		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
+
+		// DeepInfra does not track cache writes separately
+		const cacheWriteTokens = 0
+
+		// Calculate cost using OpenAI-compatible cost calculation
+		const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
+
+		// Calculate non-cached input tokens for proper reporting
+		const nonCachedInputTokens = Math.max(0, inputTokens - cacheReadTokens - cacheWriteTokens)
+
+		yield {
+			type: "usage",
+			inputTokens: nonCachedInputTokens,
+			outputTokens,
+			cacheWriteTokens,
+			cacheReadTokens,
+			totalCost,
+		}
+	}
+}
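
To illustrate the streaming contract, a sketch of a consumer — assuming ApiStream yields exactly the text and usage chunk shapes produced above (the function and prompt are illustrative):

// Illustrative consumer: accumulates text chunks and logs usage accounting.
async function runPrompt(handler: DeepInfraHandler): Promise<string> {
	let text = ""
	const stream = handler.createMessage("You are a concise assistant.", [
		{ role: "user", content: "Say hello." },
	])
	for await (const chunk of stream) {
		if (chunk.type === "text") {
			text += chunk.text
		} else if (chunk.type === "usage") {
			// inputTokens is the non-cached count: prompt_tokens minus
			// prompt_tokens_details.cached_tokens, per yieldUsage above.
			console.log(`cached reads: ${chunk.cacheReadTokens}, cost: $${chunk.totalCost}`)
		}
	}
	return text
}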

src/api/providers/fetchers/deepinfra.ts

Lines changed: 150 additions & 0 deletions

@@ -0,0 +1,150 @@
+import axios from "axios"
+import { z } from "zod"
+
+import type { ModelInfo } from "@roo-code/types"
+
+import { parseApiPrice } from "../../../shared/cost"
+
+/**
+ * DeepInfra Model Schema
+ */
+const deepInfraModelSchema = z.object({
+	model_name: z.string(),
+	type: z.string().optional(),
+	max_tokens: z.number().optional(),
+	context_length: z.number().optional(),
+	pricing: z
+		.object({
+			input: z.number().optional(),
+			output: z.number().optional(),
+			cached_input: z.number().optional(),
+		})
+		.optional(),
+	description: z.string().optional(),
+	capabilities: z.array(z.string()).optional(),
+})
+
+type DeepInfraModel = z.infer<typeof deepInfraModelSchema>
+
+/**
+ * DeepInfra Models Response Schema
+ */
+const deepInfraModelsResponseSchema = z.array(deepInfraModelSchema)
+
+type DeepInfraModelsResponse = z.infer<typeof deepInfraModelsResponseSchema>
+
+/**
+ * Fetch models from DeepInfra API
+ */
+export async function getDeepInfraModels(apiKey?: string): Promise<Record<string, ModelInfo>> {
+	const models: Record<string, ModelInfo> = {}
+	const baseURL = "https://api.deepinfra.com/v1/openai"
+
+	try {
+		// DeepInfra requires authentication to fetch models
+		if (!apiKey) {
+			console.log("DeepInfra API key not provided, returning empty models")
+			return models
+		}
+
+		const response = await axios.get<DeepInfraModelsResponse>(`${baseURL}/models`, {
+			headers: {
+				Authorization: `Bearer ${apiKey}`,
+			},
+		})
+
+		const result = deepInfraModelsResponseSchema.safeParse(response.data)
+		const data = result.success ? result.data : response.data
+
+		if (!result.success) {
+			console.error("DeepInfra models response is invalid", result.error.format())
+		}
+
+		// Process each model from the response
+		for (const model of data) {
+			// Skip non-text models
+			if (model.type && !["text", "chat", "instruct"].includes(model.type)) {
+				continue
+			}
+
+			const modelInfo: ModelInfo = {
+				maxTokens: model.max_tokens || 4096,
+				contextWindow: model.context_length || 32768,
+				supportsImages: model.capabilities?.includes("vision") || false,
+				supportsPromptCache: true, // DeepInfra supports prompt caching
+				inputPrice: model.pricing?.input ? model.pricing.input / 1000000 : 0.15, // Convert from per million to per token
+				outputPrice: model.pricing?.output ? model.pricing.output / 1000000 : 0.6,
+				cacheReadsPrice: model.pricing?.cached_input ? model.pricing.cached_input / 1000000 : undefined,
+				description: model.description,
+			}
+
+			models[model.model_name] = modelInfo
+		}
+
+		// If the API doesn't return models, provide some default popular models
+		if (Object.keys(models).length === 0) {
+			console.log("No models returned from DeepInfra API, using default models")
+
+			// Default popular models on DeepInfra
+			models["meta-llama/Llama-3.3-70B-Instruct"] = {
+				maxTokens: 8192,
+				contextWindow: 131072,
+				supportsImages: false,
+				supportsPromptCache: true,
+				inputPrice: 0.35 / 1000000,
+				outputPrice: 0.4 / 1000000,
+				cacheReadsPrice: 0.175 / 1000000,
+				description: "Meta Llama 3.3 70B Instruct model",
+			}
+
+			models["meta-llama/Llama-3.1-8B-Instruct"] = {
+				maxTokens: 4096,
+				contextWindow: 131072,
+				supportsImages: false,
+				supportsPromptCache: true,
+				inputPrice: 0.06 / 1000000,
+				outputPrice: 0.06 / 1000000,
+				cacheReadsPrice: 0.03 / 1000000,
+				description: "Meta Llama 3.1 8B Instruct model",
+			}
+
+			models["Qwen/Qwen2.5-72B-Instruct"] = {
+				maxTokens: 8192,
+				contextWindow: 131072,
+				supportsImages: false,
+				supportsPromptCache: true,
+				inputPrice: 0.35 / 1000000,
+				outputPrice: 0.4 / 1000000,
+				cacheReadsPrice: 0.175 / 1000000,
+				description: "Qwen 2.5 72B Instruct model",
+			}
+
+			models["mistralai/Mixtral-8x7B-Instruct-v0.1"] = {
+				maxTokens: 4096,
+				contextWindow: 32768,
+				supportsImages: false,
+				supportsPromptCache: true,
+				inputPrice: 0.24 / 1000000,
+				outputPrice: 0.24 / 1000000,
+				cacheReadsPrice: 0.12 / 1000000,
+				description: "Mistral Mixtral 8x7B Instruct model",
+			}
+		}
+	} catch (error) {
+		console.error(`Error fetching DeepInfra models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`)
+
+		// Return default models on error
+		models["meta-llama/Llama-3.3-70B-Instruct"] = {
+			maxTokens: 8192,
+			contextWindow: 131072,
+			supportsImages: false,
+			supportsPromptCache: true,
+			inputPrice: 0.35 / 1000000,
+			outputPrice: 0.4 / 1000000,
+			cacheReadsPrice: 0.175 / 1000000,
+			description: "Meta Llama 3.3 70B Instruct model",
+		}
+	}

+	return models
+}
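
Note the unit handling: per the fetcher's own comment, API prices are divided by 1,000,000 to convert per-million-token rates into per-token rates. A quick usage sketch (top-level await assumed for brevity):

import { getDeepInfraModels } from "./deepinfra"

const models = await getDeepInfraModels(process.env.DEEPINFRA_API_KEY)

// A listed rate of 0.35 per million input tokens is stored as
// 0.35 / 1000000 = 3.5e-7 per token.
console.log(models["meta-llama/Llama-3.3-70B-Instruct"]?.inputPrice)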

src/api/providers/fetchers/modelCache.ts

Lines changed: 5 additions & 0 deletions

@@ -19,6 +19,7 @@ import { GetModelsOptions } from "../../../shared/api"
 import { getOllamaModels } from "./ollama"
 import { getLMStudioModels } from "./lmstudio"
 import { getIOIntelligenceModels } from "./io-intelligence"
+import { getDeepInfraModels } from "./deepinfra"
 const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })

 async function writeModels(router: RouterName, data: ModelRecord) {
@@ -55,6 +56,10 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>

 	try {
 		switch (provider) {
+			case "deepinfra":
+				// DeepInfra models endpoint requires an API key
+				models = await getDeepInfraModels(options.apiKey)
+				break
 			case "openrouter":
 				models = await getOpenRouterModels()
 				break
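
A sketch of the cached fetch path — per the NodeCache config above (stdTTL of 5 × 60 seconds), repeat calls within five minutes should be served from memory instead of re-hitting the DeepInfra endpoint:

import { getModels } from "./modelCache"

// First call fetches https://api.deepinfra.com/v1/openai/models;
// a second call within 5 minutes returns the cached ModelRecord.
const models = await getModels({ provider: "deepinfra", apiKey: process.env.DEEPINFRA_API_KEY })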

src/api/providers/index.ts

Lines changed: 1 addition & 0 deletions

@@ -4,6 +4,7 @@ export { AwsBedrockHandler } from "./bedrock"
 export { CerebrasHandler } from "./cerebras"
 export { ChutesHandler } from "./chutes"
 export { ClaudeCodeHandler } from "./claude-code"
+export { DeepInfraHandler } from "./deepinfra"
 export { DeepSeekHandler } from "./deepseek"
 export { DoubaoHandler } from "./doubao"
 export { MoonshotHandler } from "./moonshot"

src/shared/api.ts

Lines changed: 2 additions & 0 deletions

@@ -19,6 +19,7 @@ export type ApiHandlerOptions = Omit<ProviderSettings, "apiProvider"> & {
 // RouterName

 const routerNames = [
+	"deepinfra",
 	"openrouter",
 	"requesty",
 	"glama",
@@ -144,6 +145,7 @@ export const getModelMaxOutputTokens = ({
 // GetModelsOptions

 export type GetModelsOptions =
+	| { provider: "deepinfra"; apiKey?: string }
 	| { provider: "openrouter" }
 	| { provider: "glama" }
 	| { provider: "requesty"; apiKey?: string; baseUrl?: string }
