17 changes: 0 additions & 17 deletions src/api/huggingface-models.ts

This file was deleted.

@@ -1,3 +1,6 @@
import { ModelInfo } from "@roo-code/types"
import { z } from "zod"

export interface HuggingFaceModel {
_id: string
id: string
@@ -52,9 +55,8 @@ const BASE_URL = "https://huggingface.co/api/models"
const CACHE_DURATION = 1000 * 60 * 60 // 1 hour

interface CacheEntry {
data: HuggingFaceModel[]
data: Record<string, ModelInfo>
timestamp: number
status: "success" | "partial" | "error"
}

let cache: CacheEntry | null = null
@@ -95,7 +97,46 @@ const requestInit: RequestInit = {
mode: "cors",
}

export async function fetchHuggingFaceModels(): Promise<HuggingFaceModel[]> {
/**
* Parse a HuggingFace model into ModelInfo format
*/
function parseHuggingFaceModel(model: HuggingFaceModel): ModelInfo {
// Extract context window from tokenizer config if available
const contextWindow = model.config.tokenizer_config?.model_max_length || 32768 // Default to 32k

// Determine if model supports images based on pipeline tag
const supportsImages = model.pipeline_tag === "image-text-to-text"

// Create a description from available metadata
const description = [
model.config.model_type ? `Type: ${model.config.model_type}` : null,
model.config.architectures?.length ? `Architecture: ${model.config.architectures[0]}` : null,
model.library_name ? `Library: ${model.library_name}` : null,
model.inferenceProviderMapping?.length
? `Providers: ${model.inferenceProviderMapping.map((p) => p.provider).join(", ")}`
: null,
]
.filter(Boolean)
.join(", ")

const modelInfo: ModelInfo = {
maxTokens: Math.min(contextWindow, 8192), // Conservative default, most models support at least 8k output
contextWindow,
supportsImages,
supportsPromptCache: false, // HuggingFace inference API doesn't support prompt caching
description,
// HuggingFace models through their inference API are generally free
inputPrice: 0,
outputPrice: 0,
}

return modelInfo
}

/**
* Fetch HuggingFace models and return them in ModelInfo format
*/
export async function getHuggingFaceModels(): Promise<Record<string, ModelInfo>> {
const now = Date.now()

// Check cache
@@ -104,6 +145,8 @@ export async function fetchHuggingFaceModels(): Promise<HuggingFaceModel[]> {
return cache.data
}

const models: Record<string, ModelInfo> = {}

try {
console.log("Fetching Hugging Face models from API...")

@@ -115,57 +158,49 @@ export async function fetchHuggingFaceModels(): Promise<HuggingFaceModel[]> {

let textGenModels: HuggingFaceModel[] = []
let imgTextModels: HuggingFaceModel[] = []
let hasErrors = false

// Process text-generation models
if (textGenResponse.status === "fulfilled" && textGenResponse.value.ok) {
textGenModels = await textGenResponse.value.json()
} else {
console.error("Failed to fetch text-generation models:", textGenResponse)
hasErrors = true
}

// Process image-text-to-text models
if (imgTextResponse.status === "fulfilled" && imgTextResponse.value.ok) {
imgTextModels = await imgTextResponse.value.json()
} else {
console.error("Failed to fetch image-text-to-text models:", imgTextResponse)
hasErrors = true
}

// Combine and filter models
const allModels = [...textGenModels, ...imgTextModels]
.filter((model) => model.inferenceProviderMapping.length > 0)
.sort((a, b) => a.id.toLowerCase().localeCompare(b.id.toLowerCase()))
const allModels = [...textGenModels, ...imgTextModels].filter(
(model) => model.inferenceProviderMapping.length > 0,
)

// Convert to ModelInfo format
for (const model of allModels) {
models[model.id] = parseHuggingFaceModel(model)
}

// Update cache
cache = {
data: allModels,
data: models,
timestamp: now,
status: hasErrors ? "partial" : "success",
}

console.log(`Fetched ${allModels.length} Hugging Face models (status: ${cache.status})`)
return allModels
console.log(`Fetched ${Object.keys(models).length} Hugging Face models`)
return models
} catch (error) {
console.error("Error fetching Hugging Face models:", error)

// Return cached data if available
if (cache) {
console.log("Using stale cached data due to fetch error")
cache.status = "error"
return cache.data
}

// No cache available, return empty array
return []
// No cache available, return empty object
return {}
}
}

export function getCachedModels(): HuggingFaceModel[] | null {
return cache?.data || null
}

export function clearCache(): void {
cache = null
}
4 changes: 4 additions & 0 deletions src/api/providers/fetchers/modelCache.ts
@@ -17,6 +17,7 @@ import { getLiteLLMModels } from "./litellm"
import { GetModelsOptions } from "../../../shared/api"
import { getOllamaModels } from "./ollama"
import { getLMStudioModels } from "./lmstudio"
import { getHuggingFaceModels } from "./huggingface"

const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })

@@ -78,6 +79,9 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
case "lmstudio":
models = await getLMStudioModels(options.baseUrl)
break
case "huggingface":
models = await getHuggingFaceModels()
break
default: {
// Ensures router is exhaustively checked if RouterName is a strict union
const exhaustiveCheck: never = provider
69 changes: 38 additions & 31 deletions src/api/providers/huggingface.ts
@@ -1,38 +1,46 @@
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

import type { ApiHandlerOptions } from "../../shared/api"
import { type ModelInfo } from "@roo-code/types"

import type { ApiHandlerOptions, ModelRecord } from "../../shared/api"
import { ApiStream } from "../transform/stream"
import { convertToOpenAiMessages } from "../transform/openai-format"
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
import { DEFAULT_HEADERS } from "./constants"
import { BaseProvider } from "./base-provider"

export class HuggingFaceHandler extends BaseProvider implements SingleCompletionHandler {
private client: OpenAI
private options: ApiHandlerOptions
import { RouterProvider } from "./router-provider"

// Default model info for fallback
const huggingFaceDefaultModelInfo: ModelInfo = {
maxTokens: 8192,
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: false,
}

export class HuggingFaceHandler extends RouterProvider implements SingleCompletionHandler {
constructor(options: ApiHandlerOptions) {
super()
this.options = options
super({
options,
name: "huggingface",
baseURL: "https://router.huggingface.co/v1",
apiKey: options.huggingFaceApiKey,
modelId: options.huggingFaceModelId,
defaultModelId: "meta-llama/Llama-3.3-70B-Instruct",
defaultModelInfo: huggingFaceDefaultModelInfo,
})

if (!this.options.huggingFaceApiKey) {
throw new Error("Hugging Face API key is required")
}

this.client = new OpenAI({
baseURL: "https://router.huggingface.co/v1",
apiKey: this.options.huggingFaceApiKey,
defaultHeaders: DEFAULT_HEADERS,
})
}

override async *createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
metadata?: ApiHandlerCreateMessageMetadata,
): ApiStream {
const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
const { id: modelId, info } = await this.fetchModel()
const temperature = this.options.modelTemperature ?? 0.7

const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
@@ -43,6 +51,11 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion
stream_options: { include_usage: true },
}

// Add max_tokens if the model info specifies it
if (info.maxTokens && info.maxTokens > 0) {
params.max_tokens = info.maxTokens
}

const stream = await this.client.chat.completions.create(params)

for await (const chunk of stream) {
@@ -66,13 +79,20 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion
}

async completePrompt(prompt: string): Promise<string> {
const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
const { id: modelId, info } = await this.fetchModel()

try {
const response = await this.client.chat.completions.create({
const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: modelId,
messages: [{ role: "user", content: prompt }],
})
}

// Add max_tokens if the model info specifies it
if (info.maxTokens && info.maxTokens > 0) {
params.max_tokens = info.maxTokens
}

const response = await this.client.chat.completions.create(params)

return response.choices[0]?.message.content || ""
} catch (error) {
@@ -83,17 +103,4 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion
throw error
}
}

override getModel() {
const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
return {
id: modelId,
info: {
maxTokens: 8192,
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: false,
},
}
}
}
16 changes: 0 additions & 16 deletions src/core/webview/webviewMessageHandler.ts
@@ -674,22 +674,6 @@ export const webviewMessageHandler = async (
// TODO: Cache like we do for OpenRouter, etc?
provider.postMessageToWebview({ type: "vsCodeLmModels", vsCodeLmModels })
break
case "requestHuggingFaceModels":
try {
const { getHuggingFaceModels } = await import("../../api/huggingface-models")
const huggingFaceModelsResponse = await getHuggingFaceModels()
provider.postMessageToWebview({
type: "huggingFaceModels",
huggingFaceModels: huggingFaceModelsResponse.models,
})
} catch (error) {
console.error("Failed to fetch Hugging Face models:", error)
provider.postMessageToWebview({
type: "huggingFaceModels",
huggingFaceModels: [],
})
}
break
case "openImage":
openImage(message.text!, { values: message.values })
break
23 changes: 0 additions & 23 deletions src/shared/ExtensionMessage.ts
@@ -67,7 +67,6 @@ export interface ExtensionMessage {
| "ollamaModels"
| "lmStudioModels"
| "vsCodeLmModels"
| "huggingFaceModels"
| "vsCodeLmApiAvailable"
| "updatePrompt"
| "systemPrompt"
@@ -137,28 +136,6 @@ export interface ExtensionMessage {
ollamaModels?: string[]
lmStudioModels?: string[]
vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[]
huggingFaceModels?: Array<{
_id: string
id: string
inferenceProviderMapping: Array<{
provider: string
providerId: string
status: "live" | "staging" | "error"
task: "conversational"
}>
trendingScore: number
config: {
architectures: string[]
model_type: string
tokenizer_config?: {
chat_template?: string | Array<{ name: string; template: string }>
model_max_length?: number
}
}
tags: string[]
pipeline_tag: "text-generation" | "image-text-to-text"
library_name?: string
}>
mcpServers?: McpServer[]
commits?: GitCommit[]
listApiConfig?: ProviderSettingsEntry[]
1 change: 0 additions & 1 deletion src/shared/WebviewMessage.ts
@@ -67,7 +67,6 @@ export interface WebviewMessage {
| "requestOllamaModels"
| "requestLmStudioModels"
| "requestVsCodeLmModels"
| "requestHuggingFaceModels"
| "openImage"
| "saveImage"
| "openFile"
12 changes: 11 additions & 1 deletion src/shared/api.ts
@@ -11,7 +11,16 @@ export type ApiHandlerOptions = Omit<ProviderSettings, "apiProvider">

// RouterName

const routerNames = ["openrouter", "requesty", "glama", "unbound", "litellm", "ollama", "lmstudio"] as const
const routerNames = [
"openrouter",
"requesty",
"glama",
"unbound",
"litellm",
"ollama",
"lmstudio",
"huggingface",
] as const

export type RouterName = (typeof routerNames)[number]

@@ -113,3 +122,4 @@ export type GetModelsOptions =
| { provider: "litellm"; apiKey: string; baseUrl: string }
| { provider: "ollama"; baseUrl?: string }
| { provider: "lmstudio"; baseUrl?: string }
| { provider: "huggingface" }