diff --git a/src/api/huggingface-models.ts b/src/api/huggingface-models.ts
deleted file mode 100644
index ec1915d0e3d..00000000000
--- a/src/api/huggingface-models.ts
+++ /dev/null
@@ -1,17 +0,0 @@
-import { fetchHuggingFaceModels, type HuggingFaceModel } from "../services/huggingface-models"
-
-export interface HuggingFaceModelsResponse {
-	models: HuggingFaceModel[]
-	cached: boolean
-	timestamp: number
-}
-
-export async function getHuggingFaceModels(): Promise<HuggingFaceModelsResponse> {
-	const models = await fetchHuggingFaceModels()
-
-	return {
-		models,
-		cached: false, // We could enhance this to track if data came from cache
-		timestamp: Date.now(),
-	}
-}
diff --git a/src/services/huggingface-models.ts b/src/api/providers/fetchers/huggingface.ts
similarity index 63%
rename from src/services/huggingface-models.ts
rename to src/api/providers/fetchers/huggingface.ts
index 9c0bc406f93..55b17ed9507 100644
--- a/src/services/huggingface-models.ts
+++ b/src/api/providers/fetchers/huggingface.ts
@@ -1,3 +1,6 @@
+import { ModelInfo } from "@roo-code/types"
+import { z } from "zod"
+
 export interface HuggingFaceModel {
 	_id: string
 	id: string
@@ -52,9 +55,8 @@ const BASE_URL = "https://huggingface.co/api/models"
 const CACHE_DURATION = 1000 * 60 * 60 // 1 hour
 
 interface CacheEntry {
-	data: HuggingFaceModel[]
+	data: Record<string, ModelInfo>
 	timestamp: number
-	status: "success" | "partial" | "error"
 }
 
 let cache: CacheEntry | null = null
@@ -95,7 +97,46 @@ const requestInit: RequestInit = {
 	mode: "cors",
 }
 
-export async function fetchHuggingFaceModels(): Promise<HuggingFaceModel[]> {
+/**
+ * Parse a HuggingFace model into ModelInfo format
+ */
+function parseHuggingFaceModel(model: HuggingFaceModel): ModelInfo {
+	// Extract context window from tokenizer config if available
+	const contextWindow = model.config.tokenizer_config?.model_max_length || 32768 // Default to 32k
+
+	// Determine if model supports images based on pipeline tag
+	const supportsImages = model.pipeline_tag === "image-text-to-text"
+
+	// Create a description from available metadata
+	const description = [
+		model.config.model_type ? `Type: ${model.config.model_type}` : null,
+		model.config.architectures?.length ? `Architecture: ${model.config.architectures[0]}` : null,
+		model.library_name ? `Library: ${model.library_name}` : null,
+		model.inferenceProviderMapping?.length
+			? `Providers: ${model.inferenceProviderMapping.map((p) => p.provider).join(", ")}`
+			: null,
+	]
+		.filter(Boolean)
+		.join(", ")
+
+	const modelInfo: ModelInfo = {
+		maxTokens: Math.min(contextWindow, 8192), // Conservative default, most models support at least 8k output
+		contextWindow,
+		supportsImages,
+		supportsPromptCache: false, // HuggingFace inference API doesn't support prompt caching
+		description,
+		// HuggingFace models through their inference API are generally free
+		inputPrice: 0,
+		outputPrice: 0,
+	}
+
+	return modelInfo
+}
+
+/**
+ * Fetch HuggingFace models and return them in ModelInfo format
+ */
+export async function getHuggingFaceModels(): Promise<Record<string, ModelInfo>> {
 	const now = Date.now()
 
 	// Check cache
@@ -104,6 +145,8 @@ export async function fetchHuggingFaceModels(): Promise<HuggingFaceModel[]> {
 		return cache.data
 	}
 
+	const models: Record<string, ModelInfo> = {}
+
 	try {
 		console.log("Fetching Hugging Face models from API...")
 
@@ -115,14 +158,12 @@ export async function fetchHuggingFaceModels(): Promise<HuggingFaceModel[]> {
 
 		let textGenModels: HuggingFaceModel[] = []
 		let imgTextModels: HuggingFaceModel[] = []
-		let hasErrors = false
 
 		// Process text-generation models
 		if (textGenResponse.status === "fulfilled" && textGenResponse.value.ok) {
 			textGenModels = await textGenResponse.value.json()
 		} else {
 			console.error("Failed to fetch text-generation models:", textGenResponse)
-			hasErrors = true
 		}
 
 		// Process image-text-to-text models
@@ -130,42 +171,36 @@ export async function fetchHuggingFaceModels(): Promise<HuggingFaceModel[]> {
 			imgTextModels = await imgTextResponse.value.json()
 		} else {
 			console.error("Failed to fetch image-text-to-text models:", imgTextResponse)
-			hasErrors = true
 		}
 
 		// Combine and filter models
-		const allModels = [...textGenModels, ...imgTextModels]
-			.filter((model) => model.inferenceProviderMapping.length > 0)
-			.sort((a, b) => a.id.toLowerCase().localeCompare(b.id.toLowerCase()))
+		const allModels = [...textGenModels, ...imgTextModels].filter(
+			(model) => model.inferenceProviderMapping.length > 0,
+		)
+
+		// Convert to ModelInfo format
+		for (const model of allModels) {
+			models[model.id] = parseHuggingFaceModel(model)
+		}
 
 		// Update cache
 		cache = {
-			data: allModels,
+			data: models,
 			timestamp: now,
"partial" : "success", } - console.log(`Fetched ${allModels.length} Hugging Face models (status: ${cache.status})`) - return allModels + console.log(`Fetched ${Object.keys(models).length} Hugging Face models`) + return models } catch (error) { console.error("Error fetching Hugging Face models:", error) // Return cached data if available if (cache) { console.log("Using stale cached data due to fetch error") - cache.status = "error" return cache.data } - // No cache available, return empty array - return [] + // No cache available, return empty object + return {} } } - -export function getCachedModels(): HuggingFaceModel[] | null { - return cache?.data || null -} - -export function clearCache(): void { - cache = null -} diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts index fef700268dc..0eb35eedba0 100644 --- a/src/api/providers/fetchers/modelCache.ts +++ b/src/api/providers/fetchers/modelCache.ts @@ -17,6 +17,7 @@ import { getLiteLLMModels } from "./litellm" import { GetModelsOptions } from "../../../shared/api" import { getOllamaModels } from "./ollama" import { getLMStudioModels } from "./lmstudio" +import { getHuggingFaceModels } from "./huggingface" const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 }) @@ -78,6 +79,9 @@ export const getModels = async (options: GetModelsOptions): Promise case "lmstudio": models = await getLMStudioModels(options.baseUrl) break + case "huggingface": + models = await getHuggingFaceModels() + break default: { // Ensures router is exhaustively checked if RouterName is a strict union const exhaustiveCheck: never = provider diff --git a/src/api/providers/huggingface.ts b/src/api/providers/huggingface.ts index 913605bd929..3370e764d0e 100644 --- a/src/api/providers/huggingface.ts +++ b/src/api/providers/huggingface.ts @@ -1,30 +1,38 @@ import OpenAI from "openai" import { Anthropic } from "@anthropic-ai/sdk" -import type { ApiHandlerOptions } from "../../shared/api" +import { type ModelInfo } from "@roo-code/types" + +import type { ApiHandlerOptions, ModelRecord } from "../../shared/api" import { ApiStream } from "../transform/stream" import { convertToOpenAiMessages } from "../transform/openai-format" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" import { DEFAULT_HEADERS } from "./constants" -import { BaseProvider } from "./base-provider" - -export class HuggingFaceHandler extends BaseProvider implements SingleCompletionHandler { - private client: OpenAI - private options: ApiHandlerOptions +import { RouterProvider } from "./router-provider" + +// Default model info for fallback +const huggingFaceDefaultModelInfo: ModelInfo = { + maxTokens: 8192, + contextWindow: 131072, + supportsImages: false, + supportsPromptCache: false, +} +export class HuggingFaceHandler extends RouterProvider implements SingleCompletionHandler { constructor(options: ApiHandlerOptions) { - super() - this.options = options + super({ + options, + name: "huggingface", + baseURL: "https://router.huggingface.co/v1", + apiKey: options.huggingFaceApiKey, + modelId: options.huggingFaceModelId, + defaultModelId: "meta-llama/Llama-3.3-70B-Instruct", + defaultModelInfo: huggingFaceDefaultModelInfo, + }) if (!this.options.huggingFaceApiKey) { throw new Error("Hugging Face API key is required") } - - this.client = new OpenAI({ - baseURL: "https://router.huggingface.co/v1", - apiKey: this.options.huggingFaceApiKey, - defaultHeaders: DEFAULT_HEADERS, - }) } override async *createMessage( @@ 
@@ -32,7 +40,7 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
-		const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
+		const { id: modelId, info } = await this.fetchModel()
 		const temperature = this.options.modelTemperature ?? 0.7
 
 		const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
@@ -43,6 +51,11 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion
 			stream_options: { include_usage: true },
 		}
 
+		// Add max_tokens if the model info specifies it
+		if (info.maxTokens && info.maxTokens > 0) {
+			params.max_tokens = info.maxTokens
+		}
+
 		const stream = await this.client.chat.completions.create(params)
 
 		for await (const chunk of stream) {
@@ -66,13 +79,20 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion
 	}
 
 	async completePrompt(prompt: string): Promise<string> {
-		const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
+		const { id: modelId, info } = await this.fetchModel()
 
 		try {
-			const response = await this.client.chat.completions.create({
+			const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
 				model: modelId,
 				messages: [{ role: "user", content: prompt }],
-			})
+			}
+
+			// Add max_tokens if the model info specifies it
+			if (info.maxTokens && info.maxTokens > 0) {
+				params.max_tokens = info.maxTokens
+			}
+
+			const response = await this.client.chat.completions.create(params)
 
 			return response.choices[0]?.message.content || ""
 		} catch (error) {
@@ -83,17 +103,4 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion
 			throw error
 		}
 	}
-
-	override getModel() {
-		const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
-		return {
-			id: modelId,
-			info: {
-				maxTokens: 8192,
-				contextWindow: 131072,
-				supportsImages: false,
-				supportsPromptCache: false,
-			},
-		}
-	}
 }
diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts
index c739c2ade8d..a0e87a49afd 100644
--- a/src/core/webview/webviewMessageHandler.ts
+++ b/src/core/webview/webviewMessageHandler.ts
@@ -674,22 +674,6 @@ export const webviewMessageHandler = async (
 			// TODO: Cache like we do for OpenRouter, etc?
provider.postMessageToWebview({ type: "vsCodeLmModels", vsCodeLmModels }) break - case "requestHuggingFaceModels": - try { - const { getHuggingFaceModels } = await import("../../api/huggingface-models") - const huggingFaceModelsResponse = await getHuggingFaceModels() - provider.postMessageToWebview({ - type: "huggingFaceModels", - huggingFaceModels: huggingFaceModelsResponse.models, - }) - } catch (error) { - console.error("Failed to fetch Hugging Face models:", error) - provider.postMessageToWebview({ - type: "huggingFaceModels", - huggingFaceModels: [], - }) - } - break case "openImage": openImage(message.text!, { values: message.values }) break diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index 000762e317a..2eb08ac7cf6 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -67,7 +67,6 @@ export interface ExtensionMessage { | "ollamaModels" | "lmStudioModels" | "vsCodeLmModels" - | "huggingFaceModels" | "vsCodeLmApiAvailable" | "updatePrompt" | "systemPrompt" @@ -137,28 +136,6 @@ export interface ExtensionMessage { ollamaModels?: string[] lmStudioModels?: string[] vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[] - huggingFaceModels?: Array<{ - _id: string - id: string - inferenceProviderMapping: Array<{ - provider: string - providerId: string - status: "live" | "staging" | "error" - task: "conversational" - }> - trendingScore: number - config: { - architectures: string[] - model_type: string - tokenizer_config?: { - chat_template?: string | Array<{ name: string; template: string }> - model_max_length?: number - } - } - tags: string[] - pipeline_tag: "text-generation" | "image-text-to-text" - library_name?: string - }> mcpServers?: McpServer[] commits?: GitCommit[] listApiConfig?: ProviderSettingsEntry[] diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index 795e2765222..53b4fa92a7e 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -67,7 +67,6 @@ export interface WebviewMessage { | "requestOllamaModels" | "requestLmStudioModels" | "requestVsCodeLmModels" - | "requestHuggingFaceModels" | "openImage" | "saveImage" | "openFile" diff --git a/src/shared/api.ts b/src/shared/api.ts index 8cbfc721336..705e5d832fd 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -11,7 +11,16 @@ export type ApiHandlerOptions = Omit // RouterName -const routerNames = ["openrouter", "requesty", "glama", "unbound", "litellm", "ollama", "lmstudio"] as const +const routerNames = [ + "openrouter", + "requesty", + "glama", + "unbound", + "litellm", + "ollama", + "lmstudio", + "huggingface", +] as const export type RouterName = (typeof routerNames)[number] @@ -113,3 +122,4 @@ export type GetModelsOptions = | { provider: "litellm"; apiKey: string; baseUrl: string } | { provider: "ollama"; baseUrl?: string } | { provider: "lmstudio"; baseUrl?: string } + | { provider: "huggingface" } diff --git a/webview-ui/src/components/settings/providers/HuggingFace.tsx b/webview-ui/src/components/settings/providers/HuggingFace.tsx index d4195492dd7..ee2dc56b53f 100644 --- a/webview-ui/src/components/settings/providers/HuggingFace.tsx +++ b/webview-ui/src/components/settings/providers/HuggingFace.tsx @@ -62,7 +62,7 @@ export const HuggingFace = ({ apiConfiguration, setApiConfigurationField }: Hugg // Fetch models when component mounts useEffect(() => { setLoading(true) - vscode.postMessage({ type: "requestHuggingFaceModels" }) + vscode.postMessage({ type: "requestRouterModels" }) 
 	}, [])
 
 	// Handle messages from extension
@@ -70,8 +70,46 @@ export const HuggingFace = ({ apiConfiguration, setApiConfigurationField }: Hugg
 			const message: ExtensionMessage = event.data
 
 			switch (message.type) {
-				case "huggingFaceModels":
-					setModels(message.huggingFaceModels || [])
+				case "routerModels":
+					// Extract HuggingFace models from routerModels
+					if (message.routerModels?.huggingface) {
+						// Convert from ModelRecord format to HuggingFaceModel array format
+						const modelArray = Object.entries(message.routerModels.huggingface).map(([id, info]) => ({
+							id,
+							_id: id,
+							inferenceProviderMapping: [
+								{
+									provider: "huggingface",
+									providerId: id,
+									status: "live" as const,
+									task: "conversational" as const,
+								},
+							],
+							trendingScore: 0,
+							config: {
+								architectures: [],
+								model_type:
+									info.description
+										?.split(", ")
+										.find((part: string) => part.startsWith("Type: "))
+										?.replace("Type: ", "") || "",
+								tokenizer_config: {
+									model_max_length: info.contextWindow,
+								},
+							},
+							tags: [],
+							pipeline_tag: info.supportsImages
+								? ("image-text-to-text" as const)
+								: ("text-generation" as const),
+							library_name: info.description
+								?.split(", ")
+								.find((part: string) => part.startsWith("Library: "))
+								?.replace("Library: ", ""),
+						}))
+						setModels(modelArray)
+					} else {
+						setModels([])
+					}
 					setLoading(false)
 					break
 			}
diff --git a/webview-ui/src/utils/__tests__/validate.test.ts b/webview-ui/src/utils/__tests__/validate.test.ts
index 3a60c27f8ad..01452995bc1 100644
--- a/webview-ui/src/utils/__tests__/validate.test.ts
+++ b/webview-ui/src/utils/__tests__/validate.test.ts
@@ -38,6 +38,7 @@ describe("Model Validation Functions", () => {
 		litellm: {},
 		ollama: {},
 		lmstudio: {},
+		huggingface: {},
 	}
 
 	const allowAllOrganization: OrganizationAllowList = {