11 changes: 11 additions & 0 deletions packages/types/src/provider-settings.ts
@@ -34,6 +34,7 @@ import {
export const providerNames = [
"anthropic",
"claude-code",
"deepinfra",
"glama",
"openrouter",
"bedrock",
@@ -294,6 +295,11 @@ const cerebrasSchema = apiModelIdProviderModelSchema.extend({
cerebrasApiKey: z.string().optional(),
})

const deepInfraSchema = baseProviderSettingsSchema.extend({
deepInfraApiKey: z.string().optional(),
deepInfraModelId: z.string().optional(),
})

const sambaNovaSchema = apiModelIdProviderModelSchema.extend({
sambaNovaApiKey: z.string().optional(),
})
@@ -336,6 +342,7 @@ const defaultSchema = z.object({
export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProvider", [
anthropicSchema.merge(z.object({ apiProvider: z.literal("anthropic") })),
claudeCodeSchema.merge(z.object({ apiProvider: z.literal("claude-code") })),
deepInfraSchema.merge(z.object({ apiProvider: z.literal("deepinfra") })),
glamaSchema.merge(z.object({ apiProvider: z.literal("glama") })),
openRouterSchema.merge(z.object({ apiProvider: z.literal("openrouter") })),
bedrockSchema.merge(z.object({ apiProvider: z.literal("bedrock") })),
@@ -376,6 +383,7 @@ export const providerSettingsSchema = z.object({
apiProvider: providerNamesSchema.optional(),
...anthropicSchema.shape,
...claudeCodeSchema.shape,
...deepInfraSchema.shape,
...glamaSchema.shape,
...openRouterSchema.shape,
...bedrockSchema.shape,
@@ -426,6 +434,7 @@ export const PROVIDER_SETTINGS_KEYS = providerSettingsSchema.keyof().options

export const MODEL_ID_KEYS: Partial<keyof ProviderSettings>[] = [
"apiModelId",
"deepInfraModelId",
"glamaModelId",
"openRouterModelId",
"openAiModelId",
@@ -489,6 +498,7 @@ export const MODELS_BY_PROVIDER: Record<
label: "Chutes AI",
models: Object.keys(chutesModels),
},
deepinfra: { id: "deepinfra", label: "DeepInfra", models: [] },
"claude-code": { id: "claude-code", label: "Claude Code", models: Object.keys(claudeCodeModels) },
deepseek: {
id: "deepseek",
@@ -563,6 +573,7 @@ export const MODELS_BY_PROVIDER: Record<
}

export const dynamicProviders = [
"deepinfra",
"glama",
"huggingface",
"litellm",
12 changes: 12 additions & 0 deletions packages/types/src/providers/deepinfra.ts
@@ -0,0 +1,12 @@
import type { ModelInfo } from "../model.js"

// DeepInfra models are fetched dynamically from their API
// This type represents the model IDs that will be available
export type DeepInfraModelId = string

// Default model to use when none is specified
export const deepInfraDefaultModelId: DeepInfraModelId = "meta-llama/Llama-3.3-70B-Instruct"

// DeepInfra models will be fetched dynamically, so we provide an empty object
// The actual models will be populated at runtime via the API
export const deepInfraModels = {} as const satisfies Record<string, ModelInfo>
1 change: 1 addition & 0 deletions packages/types/src/providers/index.ts
@@ -3,6 +3,7 @@ export * from "./bedrock.js"
export * from "./cerebras.js"
export * from "./chutes.js"
export * from "./claude-code.js"
export * from "./deepinfra.js"
export * from "./deepseek.js"
export * from "./doubao.js"
export * from "./featherless.js"
3 changes: 3 additions & 0 deletions src/api/index.ts
@@ -9,6 +9,7 @@ import {
AnthropicHandler,
AwsBedrockHandler,
CerebrasHandler,
DeepInfraHandler,
OpenRouterHandler,
VertexHandler,
AnthropicVertexHandler,
@@ -114,6 +115,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
return new GeminiHandler(options)
case "openai-native":
return new OpenAiNativeHandler(options)
case "deepinfra":
return new DeepInfraHandler(options)
case "deepseek":
return new DeepSeekHandler(options)
case "doubao":
102 changes: 102 additions & 0 deletions src/api/providers/deepinfra.ts
@@ -0,0 +1,102 @@
import { type DeepInfraModelId, deepInfraDefaultModelId } from "@roo-code/types"
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

import type { ApiHandlerOptions } from "../../shared/api"
import type { ApiHandlerCreateMessageMetadata } from "../index"
import type { ModelInfo } from "@roo-code/types"
import { ApiStream } from "../transform/stream"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { calculateApiCostOpenAI } from "../../shared/cost"

import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"

// Enhanced usage interface to support DeepInfra's cached token fields
interface DeepInfraUsage extends OpenAI.CompletionUsage {
prompt_tokens_details?: {
cached_tokens?: number
}
}

export class DeepInfraHandler extends BaseOpenAiCompatibleProvider<DeepInfraModelId> {
constructor(options: ApiHandlerOptions) {
// Initialize with empty models; they will be populated dynamically
super({
...options,
providerName: "DeepInfra",
baseURL: "https://api.deepinfra.com/v1/openai",
apiKey: options.deepInfraApiKey,
Is this intentional? The constructor doesn't validate that an API key was provided, unlike the base-class pattern used by other providers. This could lead to runtime errors when the API key is missing.

Consider adding validation before the super() call (statements may run there as long as they don't reference this):

Suggested change
constructor(options: ApiHandlerOptions) {
	if (!options.deepInfraApiKey) {
		throw new Error("DeepInfra API key is required")
	}
	super({

defaultProviderModelId: deepInfraDefaultModelId,
providerModels: {},
defaultTemperature: 0.7,
})
}

override getModel() {
const modelId = this.options.deepInfraModelId || deepInfraDefaultModelId

// For DeepInfra, we use a default model configuration
// The actual model info will be fetched dynamically via the fetcher
const defaultModelInfo: ModelInfo = {
This method returns hardcoded default model info instead of using actual fetched model data from the router models. Could we fetch the actual model info dynamically like OpenRouter does?

The current approach might not reflect the actual model capabilities and pricing.
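
For example, a rough sketch of that direction (fetchModel is a hypothetical helper here; getModels and its "deepinfra" option are added elsewhere in this PR):

import { getModels } from "./fetchers/modelCache"

// Hypothetical: prefer model info from the cached router models,
// falling back to the hardcoded defaults until the fetcher has run.
async fetchModel() {
	const models = await getModels({ provider: "deepinfra", apiKey: this.options.deepInfraApiKey })
	const id = this.options.deepInfraModelId || deepInfraDefaultModelId
	return { id, info: models[id] ?? this.getModel().info }
}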

maxTokens: 4096,
contextWindow: 32768,
supportsImages: false,
supportsPromptCache: true,
inputPrice: 0.15,
outputPrice: 0.6,
cacheReadsPrice: 0.075, // 50% discount for cached tokens
description: "DeepInfra model",
}

return { id: modelId, info: defaultModelInfo }
}

override async *createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
metadata?: ApiHandlerCreateMessageMetadata,
): ApiStream {
const stream = await this.createStream(systemPrompt, messages, metadata)

for await (const chunk of stream) {
const delta = chunk.choices[0]?.delta

if (delta?.content) {
yield {
type: "text",
text: delta.content,
}
}

if (chunk.usage) {
yield* this.yieldUsage(chunk.usage as DeepInfraUsage)
}
}
}

private async *yieldUsage(usage: DeepInfraUsage | undefined): ApiStream {
const { info } = this.getModel()
const inputTokens = usage?.prompt_tokens || 0
const outputTokens = usage?.completion_tokens || 0

const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0

// DeepInfra does not track cache writes separately
const cacheWriteTokens = 0

// Calculate cost using OpenAI-compatible cost calculation
const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)

// Calculate non-cached input tokens for proper reporting
const nonCachedInputTokens = Math.max(0, inputTokens - cacheReadTokens - cacheWriteTokens)

yield {
type: "usage",
inputTokens: nonCachedInputTokens,
outputTokens,
cacheWriteTokens,
cacheReadTokens,
totalCost,
}
}
}
150 changes: 150 additions & 0 deletions src/api/providers/fetchers/deepinfra.ts
@@ -0,0 +1,150 @@
import axios from "axios"
import { z } from "zod"

import type { ModelInfo } from "@roo-code/types"

import { parseApiPrice } from "../../../shared/cost"
This import is unused. The parseApiPrice function isn't called anywhere in this file.


/**
* DeepInfra Model Schema
*/
const deepInfraModelSchema = z.object({
model_name: z.string(),
type: z.string().optional(),
max_tokens: z.number().optional(),
context_length: z.number().optional(),
pricing: z
.object({
input: z.number().optional(),
output: z.number().optional(),
cached_input: z.number().optional(),
})
.optional(),
description: z.string().optional(),
capabilities: z.array(z.string()).optional(),
})

type DeepInfraModel = z.infer<typeof deepInfraModelSchema>

/**
* DeepInfra Models Response Schema
*/
const deepInfraModelsResponseSchema = z.array(deepInfraModelSchema)

type DeepInfraModelsResponse = z.infer<typeof deepInfraModelsResponseSchema>

/**
* Fetch models from DeepInfra API
*/
export async function getDeepInfraModels(apiKey?: string): Promise<Record<string, ModelInfo>> {
const models: Record<string, ModelInfo> = {}
const baseURL = "https://api.deepinfra.com/v1/openai"

try {
// DeepInfra requires authentication to fetch models
if (!apiKey) {
console.log("DeepInfra API key not provided, returning empty models")
return models
}

const response = await axios.get<DeepInfraModelsResponse>(`${baseURL}/models`, {
headers: {
Authorization: `Bearer ${apiKey}`,
},
})

const result = deepInfraModelsResponseSchema.safeParse(response.data)
const data = result.success ? result.data : response.data

if (!result.success) {
console.error("DeepInfra models response is invalid", result.error.format())
}

// Process each model from the response
for (const model of data) {
// Skip non-text models
if (model.type && !["text", "chat", "instruct"].includes(model.type)) {
continue
}

const modelInfo: ModelInfo = {
maxTokens: model.max_tokens || 4096,
contextWindow: model.context_length || 32768,
supportsImages: model.capabilities?.includes("vision") || false,
supportsPromptCache: true, // DeepInfra supports prompt caching
inputPrice: model.pricing?.input ? model.pricing.input / 1000000 : 0.15, // Convert from per million to per token
outputPrice: model.pricing?.output ? model.pricing.output / 1000000 : 0.6,
cacheReadsPrice: model.pricing?.cached_input ? model.pricing.cached_input / 1000000 : undefined,
description: model.description,
}

models[model.model_name] = modelInfo
}

// If the API doesn't return models, provide some default popular models
if (Object.keys(models).length === 0) {
console.log("No models returned from DeepInfra API, using default models")

// Default popular models on DeepInfra
models["meta-llama/Llama-3.3-70B-Instruct"] = {
maxTokens: 8192,
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: true,
inputPrice: 0.35 / 1000000,
outputPrice: 0.4 / 1000000,
cacheReadsPrice: 0.175 / 1000000,
description: "Meta Llama 3.3 70B Instruct model",
}

models["meta-llama/Llama-3.1-8B-Instruct"] = {
maxTokens: 4096,
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: true,
inputPrice: 0.06 / 1000000,
outputPrice: 0.06 / 1000000,
cacheReadsPrice: 0.03 / 1000000,
description: "Meta Llama 3.1 8B Instruct model",
}

models["Qwen/Qwen2.5-72B-Instruct"] = {
maxTokens: 8192,
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: true,
inputPrice: 0.35 / 1000000,
outputPrice: 0.4 / 1000000,
cacheReadsPrice: 0.175 / 1000000,
description: "Qwen 2.5 72B Instruct model",
}

models["mistralai/Mixtral-8x7B-Instruct-v0.1"] = {
maxTokens: 4096,
contextWindow: 32768,
supportsImages: false,
supportsPromptCache: true,
inputPrice: 0.24 / 1000000,
outputPrice: 0.24 / 1000000,
cacheReadsPrice: 0.12 / 1000000,
description: "Mistral Mixtral 8x7B Instruct model",
}
}
} catch (error) {
console.error(`Error fetching DeepInfra models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`)
The error handling here uses console.error but doesn't follow the same pattern as other fetchers. Consider handling errors more gracefully without exposing the full error object structure in logs.
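
For instance, a minimal sketch of a terser log (this helper is an assumption, not an existing utility in the repo):

// Log only the message; serializing the whole error object can leak
// request internals such as headers into the logs.
function describeError(error: unknown): string {
	return error instanceof Error ? error.message : String(error)
}

// In the catch block above:
// console.error(`Error fetching DeepInfra models: ${describeError(error)}`)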


// Return default models on error
models["meta-llama/Llama-3.3-70B-Instruct"] = {
maxTokens: 8192,
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: true,
inputPrice: 0.35 / 1000000,
outputPrice: 0.4 / 1000000,
cacheReadsPrice: 0.175 / 1000000,
description: "Meta Llama 3.3 70B Instruct model",
}
}

return models
}
5 changes: 5 additions & 0 deletions src/api/providers/fetchers/modelCache.ts
@@ -19,6 +19,7 @@ import { GetModelsOptions } from "../../../shared/api"
import { getOllamaModels } from "./ollama"
import { getLMStudioModels } from "./lmstudio"
import { getIOIntelligenceModels } from "./io-intelligence"
import { getDeepInfraModels } from "./deepinfra"
const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })

async function writeModels(router: RouterName, data: ModelRecord) {
@@ -55,6 +56,10 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>

try {
switch (provider) {
case "deepinfra":
// DeepInfra models endpoint requires an API key
models = await getDeepInfraModels(options.apiKey)
break
case "openrouter":
models = await getOpenRouterModels()
break
1 change: 1 addition & 0 deletions src/api/providers/index.ts
@@ -4,6 +4,7 @@ export { AwsBedrockHandler } from "./bedrock"
export { CerebrasHandler } from "./cerebras"
export { ChutesHandler } from "./chutes"
export { ClaudeCodeHandler } from "./claude-code"
export { DeepInfraHandler } from "./deepinfra"
export { DeepSeekHandler } from "./deepseek"
export { DoubaoHandler } from "./doubao"
export { MoonshotHandler } from "./moonshot"
Expand Down
2 changes: 2 additions & 0 deletions src/shared/api.ts
@@ -19,6 +19,7 @@ export type ApiHandlerOptions = Omit<ProviderSettings, "apiProvider"> & {
// RouterName

const routerNames = [
"deepinfra",
"openrouter",
"requesty",
"glama",
@@ -144,6 +145,7 @@ export const getModelMaxOutputTokens = ({
// GetModelsOptions

export type GetModelsOptions =
| { provider: "deepinfra"; apiKey?: string }
| { provider: "openrouter" }
| { provider: "glama" }
| { provider: "requesty"; apiKey?: string; baseUrl?: string }