import axios from "axios"
import { z } from "zod"

import type { ModelInfo } from "@roo-code/types"

import { parseApiPrice } from "../../../shared/cost"

/**
 * DeepInfra Model Schema
 */
const deepInfraModelSchema = z.object({
	model_name: z.string(),
	type: z.string().optional(),
	max_tokens: z.number().optional(),
	context_length: z.number().optional(),
	pricing: z
		.object({
			input: z.number().optional(),
			output: z.number().optional(),
			cached_input: z.number().optional(),
		})
		.optional(),
	description: z.string().optional(),
	capabilities: z.array(z.string()).optional(),
})
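
// The /models endpoint is expected to return an array of entries shaped
// like the example below. Values are illustrative, not real DeepInfra
// catalog data, and the pricing units (USD per token) are an assumption
// that getDeepInfraModels normalizes below:
//
//   [
//     {
//       "model_name": "meta-llama/Llama-3.3-70B-Instruct",
//       "type": "text",
//       "max_tokens": 8192,
//       "context_length": 131072,
//       "pricing": { "input": 3.5e-7, "output": 4e-7 },
//       "capabilities": ["chat"]
//     }
//   ]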

type DeepInfraModel = z.infer<typeof deepInfraModelSchema>

/**
 * DeepInfra Models Response Schema
 */
const deepInfraModelsResponseSchema = z.array(deepInfraModelSchema)

type DeepInfraModelsResponse = z.infer<typeof deepInfraModelsResponseSchema>

/**
 * Fetch the available models from the DeepInfra API and map them to the
 * ModelInfo format. An API key is required; without one an empty record
 * is returned.
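 *
 * @example
 * // Hypothetical usage; DEEPINFRA_API_KEY is a placeholder variable.
 * const models = await getDeepInfraModels(process.env.DEEPINFRA_API_KEY)
 * console.log(`Loaded ${Object.keys(models).length} DeepInfra models`)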
 */
export async function getDeepInfraModels(apiKey?: string): Promise<Record<string, ModelInfo>> {
	const models: Record<string, ModelInfo> = {}
	const baseURL = "https://api.deepinfra.com/v1/openai"

	try {
		// DeepInfra requires authentication to fetch models.
		if (!apiKey) {
			console.log("DeepInfra API key not provided, returning empty models")
			return models
		}

		const response = await axios.get<DeepInfraModelsResponse>(`${baseURL}/models`, {
			headers: {
				Authorization: `Bearer ${apiKey}`,
			},
		})

		const result = deepInfraModelsResponseSchema.safeParse(response.data)
		const data = result.success ? result.data : response.data

		if (!result.success) {
			console.error("DeepInfra models response is invalid", result.error.format())
		}

		// Process each model from the response.
		for (const model of data) {
			// Skip models that are not text-generation models.
			if (model.type && !["text", "chat", "instruct"].includes(model.type)) {
				continue
			}

			const modelInfo: ModelInfo = {
				maxTokens: model.max_tokens || 4096,
				contextWindow: model.context_length || 32768,
				supportsImages: model.capabilities?.includes("vision") ?? false,
				supportsPromptCache: true, // DeepInfra supports prompt caching.
				// The API is assumed to report prices per token; parseApiPrice
				// normalizes them to USD per million tokens, with per-million
				// fallbacks when no pricing is returned.
				inputPrice: parseApiPrice(model.pricing?.input) ?? 0.15,
				outputPrice: parseApiPrice(model.pricing?.output) ?? 0.6,
				cacheReadsPrice: parseApiPrice(model.pricing?.cached_input),
				description: model.description,
			}

			models[model.model_name] = modelInfo
		}

		// If the API doesn't return models, provide some default popular models.
		if (Object.keys(models).length === 0) {
			console.log("No models returned from DeepInfra API, using default models")

			// Default popular models on DeepInfra. Prices are expressed in
			// USD per million tokens, matching the ModelInfo convention.
			models["meta-llama/Llama-3.3-70B-Instruct"] = {
				maxTokens: 8192,
				contextWindow: 131072,
				supportsImages: false,
				supportsPromptCache: true,
				inputPrice: 0.35,
				outputPrice: 0.4,
				cacheReadsPrice: 0.175,
				description: "Meta Llama 3.3 70B Instruct model",
			}

			models["meta-llama/Llama-3.1-8B-Instruct"] = {
				maxTokens: 4096,
				contextWindow: 131072,
				supportsImages: false,
				supportsPromptCache: true,
				inputPrice: 0.06,
				outputPrice: 0.06,
				cacheReadsPrice: 0.03,
				description: "Meta Llama 3.1 8B Instruct model",
			}

			models["Qwen/Qwen2.5-72B-Instruct"] = {
				maxTokens: 8192,
				contextWindow: 131072,
				supportsImages: false,
				supportsPromptCache: true,
				inputPrice: 0.35,
				outputPrice: 0.4,
				cacheReadsPrice: 0.175,
				description: "Qwen 2.5 72B Instruct model",
			}

			models["mistralai/Mixtral-8x7B-Instruct-v0.1"] = {
				maxTokens: 4096,
				contextWindow: 32768,
				supportsImages: false,
				supportsPromptCache: true,
				inputPrice: 0.24,
				outputPrice: 0.24,
				cacheReadsPrice: 0.12,
				description: "Mistral Mixtral 8x7B Instruct model",
			}
		}
	} catch (error) {
		console.error(`Error fetching DeepInfra models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`)

		// Fall back to a single known model on error (prices in USD per million tokens).
		models["meta-llama/Llama-3.3-70B-Instruct"] = {
			maxTokens: 8192,
			contextWindow: 131072,
			supportsImages: false,
			supportsPromptCache: true,
			inputPrice: 0.35,
			outputPrice: 0.4,
			cacheReadsPrice: 0.175,
			description: "Meta Llama 3.3 70B Instruct model",
		}
	}

	return models
}
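
// Illustrative consumption of the returned record. The model id is an
// example; actual availability depends on the live catalog:
//
//   const models = await getDeepInfraModels(apiKey)
//   const info = models["meta-llama/Llama-3.3-70B-Instruct"]
//   // inputPrice/outputPrice are USD per million tokens.
//   console.log(info?.contextWindow, info?.inputPrice)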