# feat: add DeepInfra as a model provider #7662
Changes from all commits
New file (`@@ -0,0 +1,12 @@`):

```typescript
import type { ModelInfo } from "../model.js"

// DeepInfra models are fetched dynamically from their API
// This type represents the model IDs that will be available
export type DeepInfraModelId = string

// Default model to use when none is specified
export const deepInfraDefaultModelId: DeepInfraModelId = "meta-llama/Llama-3.3-70B-Instruct"

// DeepInfra models will be fetched dynamically, so we provide an empty object
// The actual models will be populated at runtime via the API
export const deepInfraModels = {} as const satisfies Record<string, ModelInfo>
```
New file (`@@ -0,0 +1,102 @@`):

```typescript
import { type DeepInfraModelId, deepInfraDefaultModelId } from "@roo-code/types"
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

import type { ApiHandlerOptions } from "../../shared/api"
import type { ApiHandlerCreateMessageMetadata } from "../index"
import type { ModelInfo } from "@roo-code/types"
import { ApiStream } from "../transform/stream"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { calculateApiCostOpenAI } from "../../shared/cost"

import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"

// Enhanced usage interface to support DeepInfra's cached token fields
interface DeepInfraUsage extends OpenAI.CompletionUsage {
	prompt_tokens_details?: {
		cached_tokens?: number
	}
}

export class DeepInfraHandler extends BaseOpenAiCompatibleProvider<DeepInfraModelId> {
	constructor(options: ApiHandlerOptions) {
		// Initialize with empty models, will be populated dynamically
		super({
			...options,
			providerName: "DeepInfra",
			baseURL: "https://api.deepinfra.com/v1/openai",
			apiKey: options.deepInfraApiKey,
			defaultProviderModelId: deepInfraDefaultModelId,
			providerModels: {},
			defaultTemperature: 0.7,
		})
	}

	override getModel() {
		const modelId = this.options.deepInfraModelId || deepInfraDefaultModelId

		// For DeepInfra, we use a default model configuration
		// The actual model info will be fetched dynamically via the fetcher
		const defaultModelInfo: ModelInfo = {
```
> **Review comment:** This method returns hardcoded default model info instead of using actual fetched model data from the router models. Could we fetch the actual model info dynamically like OpenRouter does? The current approach might not reflect the actual model capabilities and pricing.
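A minimal sketch of the direction the reviewer suggests, assuming the fetched model map is made available to the handler; the `resolveModelInfo` helper and its `fetchedModels` parameter are hypothetical, not part of this PR:

```typescript
import type { ModelInfo } from "@roo-code/types"

// Hypothetical helper: prefer dynamically fetched model info over the
// hardcoded defaults, falling back only when the fetcher hasn't run yet.
function resolveModelInfo(
	modelId: string,
	fetchedModels: Record<string, ModelInfo>,
	fallback: ModelInfo,
): { id: string; info: ModelInfo } {
	return { id: modelId, info: fetchedModels[modelId] ?? fallback }
}
```

Caching the router's model list and consulting it in `getModel()` is roughly what the OpenRouter comparison points at.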
The hunk continues:

```typescript
			maxTokens: 4096,
			contextWindow: 32768,
			supportsImages: false,
			supportsPromptCache: true,
			inputPrice: 0.15,
			outputPrice: 0.6,
			cacheReadsPrice: 0.075, // 50% discount for cached tokens
			description: "DeepInfra model",
		}

		return { id: modelId, info: defaultModelInfo }
	}

	override async *createMessage(
		systemPrompt: string,
		messages: Anthropic.Messages.MessageParam[],
		metadata?: ApiHandlerCreateMessageMetadata,
	): ApiStream {
		const stream = await this.createStream(systemPrompt, messages, metadata)

		for await (const chunk of stream) {
			const delta = chunk.choices[0]?.delta

			if (delta?.content) {
				yield {
					type: "text",
					text: delta.content,
				}
			}

			if (chunk.usage) {
				yield* this.yieldUsage(chunk.usage as DeepInfraUsage)
			}
		}
	}

	private async *yieldUsage(usage: DeepInfraUsage | undefined): ApiStream {
		const { info } = this.getModel()
		const inputTokens = usage?.prompt_tokens || 0
		const outputTokens = usage?.completion_tokens || 0

		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0

		// DeepInfra does not track cache writes separately
		const cacheWriteTokens = 0

		// Calculate cost using OpenAI-compatible cost calculation
		const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)

		// Calculate non-cached input tokens for proper reporting
		const nonCachedInputTokens = Math.max(0, inputTokens - cacheReadTokens - cacheWriteTokens)

		yield {
			type: "usage",
			inputTokens: nonCachedInputTokens,
			outputTokens,
			cacheWriteTokens,
			cacheReadTokens,
			totalCost,
		}
	}
}
```
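For orientation, a usage sketch of the handler. The option fields shown are the ones this diff actually reads (`deepInfraApiKey`, `deepInfraModelId`); the cast papers over whatever else `ApiHandlerOptions` requires, and the chunk shapes follow the `yield` statements above:

```typescript
// Illustrative only — streams text chunks, then a final usage record.
const handler = new DeepInfraHandler({
	deepInfraApiKey: process.env.DEEPINFRA_API_KEY,
	deepInfraModelId: "meta-llama/Llama-3.3-70B-Instruct",
} as ApiHandlerOptions)

for await (const chunk of handler.createMessage("You are a helpful assistant.", [
	{ role: "user", content: "Hello" },
])) {
	if (chunk.type === "text") process.stdout.write(chunk.text)
	if (chunk.type === "usage") console.log("\ntotal cost:", chunk.totalCost)
}
```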
New file (`@@ -0,0 +1,150 @@`):

```typescript
import axios from "axios"
import { z } from "zod"

import type { ModelInfo } from "@roo-code/types"

import { parseApiPrice } from "../../../shared/cost"
```
> **Review comment:** This import is unused. The `parseApiPrice` function isn't called anywhere in this file.
The file continues:

```typescript
/**
 * DeepInfra Model Schema
 */
const deepInfraModelSchema = z.object({
	model_name: z.string(),
	type: z.string().optional(),
	max_tokens: z.number().optional(),
	context_length: z.number().optional(),
	pricing: z
		.object({
			input: z.number().optional(),
			output: z.number().optional(),
			cached_input: z.number().optional(),
		})
		.optional(),
	description: z.string().optional(),
	capabilities: z.array(z.string()).optional(),
})

type DeepInfraModel = z.infer<typeof deepInfraModelSchema>

/**
 * DeepInfra Models Response Schema
 */
const deepInfraModelsResponseSchema = z.array(deepInfraModelSchema)

type DeepInfraModelsResponse = z.infer<typeof deepInfraModelsResponseSchema>

/**
 * Fetch models from DeepInfra API
 */
export async function getDeepInfraModels(apiKey?: string): Promise<Record<string, ModelInfo>> {
	const models: Record<string, ModelInfo> = {}
	const baseURL = "https://api.deepinfra.com/v1/openai"

	try {
		// DeepInfra requires authentication to fetch models
		if (!apiKey) {
			console.log("DeepInfra API key not provided, returning empty models")
			return models
		}

		const response = await axios.get<DeepInfraModelsResponse>(`${baseURL}/models`, {
			headers: {
				Authorization: `Bearer ${apiKey}`,
			},
		})

		const result = deepInfraModelsResponseSchema.safeParse(response.data)
		const data = result.success ? result.data : response.data

		if (!result.success) {
			console.error("DeepInfra models response is invalid", result.error.format())
		}

		// Process each model from the response
		for (const model of data) {
			// Skip non-text models
			if (model.type && !["text", "chat", "instruct"].includes(model.type)) {
				continue
			}

			const modelInfo: ModelInfo = {
				maxTokens: model.max_tokens || 4096,
				contextWindow: model.context_length || 32768,
				supportsImages: model.capabilities?.includes("vision") || false,
				supportsPromptCache: true, // DeepInfra supports prompt caching
				inputPrice: model.pricing?.input ? model.pricing.input / 1000000 : 0.15, // Convert from per-million to per-token
				outputPrice: model.pricing?.output ? model.pricing.output / 1000000 : 0.6,
				cacheReadsPrice: model.pricing?.cached_input ? model.pricing.cached_input / 1000000 : undefined,
				description: model.description,
			}

			models[model.model_name] = modelInfo
		}

		// If the API doesn't return models, provide some default popular models
		if (Object.keys(models).length === 0) {
			console.log("No models returned from DeepInfra API, using default models")

			// Default popular models on DeepInfra
			models["meta-llama/Llama-3.3-70B-Instruct"] = {
				maxTokens: 8192,
				contextWindow: 131072,
				supportsImages: false,
				supportsPromptCache: true,
				inputPrice: 0.35 / 1000000,
				outputPrice: 0.4 / 1000000,
				cacheReadsPrice: 0.175 / 1000000,
				description: "Meta Llama 3.3 70B Instruct model",
			}

			models["meta-llama/Llama-3.1-8B-Instruct"] = {
				maxTokens: 4096,
				contextWindow: 131072,
				supportsImages: false,
				supportsPromptCache: true,
				inputPrice: 0.06 / 1000000,
				outputPrice: 0.06 / 1000000,
				cacheReadsPrice: 0.03 / 1000000,
				description: "Meta Llama 3.1 8B Instruct model",
			}

			models["Qwen/Qwen2.5-72B-Instruct"] = {
				maxTokens: 8192,
				contextWindow: 131072,
				supportsImages: false,
				supportsPromptCache: true,
				inputPrice: 0.35 / 1000000,
				outputPrice: 0.4 / 1000000,
				cacheReadsPrice: 0.175 / 1000000,
				description: "Qwen 2.5 72B Instruct model",
			}

			models["mistralai/Mixtral-8x7B-Instruct-v0.1"] = {
				maxTokens: 4096,
				contextWindow: 32768,
				supportsImages: false,
				supportsPromptCache: true,
				inputPrice: 0.24 / 1000000,
				outputPrice: 0.24 / 1000000,
				cacheReadsPrice: 0.12 / 1000000,
				description: "Mistral Mixtral 8x7B Instruct model",
			}
		}
	} catch (error) {
		console.error(`Error fetching DeepInfra models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`)
```
> **Review comment:** The error handling here uses `console.error` but doesn't follow the same pattern as other fetchers. Consider handling errors more gracefully without exposing the full error object structure in logs.
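A sketch of a quieter handler, assuming the convention in the other fetchers is to log a short message and fall through to defaults (the exact house pattern is not shown in this diff):

```typescript
try {
	// ... fetch and parse models as above ...
} catch (error) {
	// Log only the message, not the serialized error object.
	const message = error instanceof Error ? error.message : String(error)
	console.error(`Error fetching DeepInfra models: ${message}`)
	// Fall through to the hardcoded default model below.
}
```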
The file ends:

```typescript
		// Return default models on error
		models["meta-llama/Llama-3.3-70B-Instruct"] = {
			maxTokens: 8192,
			contextWindow: 131072,
			supportsImages: false,
			supportsPromptCache: true,
			inputPrice: 0.35 / 1000000,
			outputPrice: 0.4 / 1000000,
			cacheReadsPrice: 0.175 / 1000000,
			description: "Meta Llama 3.3 70B Instruct model",
		}
	}

	return models
}
```
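A quick usage sketch of the fetcher (the environment variable name is illustrative):

```typescript
// Fetch the model map and inspect the default model's entry, if present.
const models = await getDeepInfraModels(process.env.DEEPINFRA_API_KEY)
const info = models["meta-llama/Llama-3.3-70B-Instruct"]
console.log(`Loaded ${Object.keys(models).length} models`, info?.contextWindow)
```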
> **Review comment:** Is this intentional? The constructor doesn't validate if the API key is provided, unlike the base class pattern used by other providers. This could lead to runtime errors when the API key is missing. Consider adding validation:
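The suggested snippet itself was not captured in this thread. A plausible reconstruction, assuming the base-class pattern is to throw on a missing key, grounded in the PR's own constructor:

```typescript
constructor(options: ApiHandlerOptions) {
	// Hypothetical reconstruction of the reviewer's suggestion: fail fast
	// when the key is absent instead of surfacing an auth error later.
	if (!options.deepInfraApiKey) {
		throw new Error("DeepInfra API key is required")
	}
	super({
		...options,
		providerName: "DeepInfra",
		baseURL: "https://api.deepinfra.com/v1/openai",
		apiKey: options.deepInfraApiKey,
		defaultProviderModelId: deepInfraDefaultModelId,
		providerModels: {},
		defaultTemperature: 0.7,
	})
}
```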