83 changes: 83 additions & 0 deletions huggingface-refactor-plan.md
@@ -0,0 +1,83 @@
# HuggingFace Provider Refactoring Plan

## Overview

The HuggingFace provider implementation needs to be refactored to match the established pattern used by other providers that fetch models via network calls (e.g., OpenRouter, Glama, Ollama).

## Current Implementation Issues

1. **File locations are incorrect:**

- `src/services/huggingface-models.ts` - Should be in `src/api/providers/fetchers/`
- `src/api/huggingface-models.ts` - Unnecessary wrapper, should be removed

2. **Pattern mismatch:**
- Current implementation returns raw HuggingFace model data
- Should return `ModelInfo` records like other providers
- Not integrated with the `modelCache.ts` system
- Provider doesn't use `RouterProvider` base class or `fetchModel` pattern

## Established Pattern (from other providers)

### 1. Fetcher Pattern (`src/api/providers/fetchers/`)

- Fetcher files export a function like `getHuggingFaceModels()` that returns `Record<string, ModelInfo>`
- Fetchers handle API calls and transform raw data to `ModelInfo` format
- Example: `getOpenRouterModels()`, `getGlamaModels()`, `getOllamaModels()`
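
As a rough illustration of that shape (the endpoint, response field names, and `getExampleModels` are placeholders, not an existing fetcher):

```typescript
import axios from "axios"
import type { ModelInfo } from "@roo-code/types"

// Sketch of the common fetcher shape: call the provider's model-listing API,
// then map each raw entry into a ModelInfo record keyed by model ID.
export async function getExampleModels(): Promise<Record<string, ModelInfo>> {
	const models: Record<string, ModelInfo> = {}

	const response = await axios.get("https://example-provider.invalid/api/models")

	for (const raw of response.data ?? []) {
		models[raw.id] = {
			maxTokens: raw.max_output_tokens,
			contextWindow: raw.context_length ?? 32768,
			supportsImages: Boolean(raw.supports_images),
			supportsPromptCache: false,
			description: raw.description,
		}
	}

	return models
}
```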

### 2. Provider Pattern (`src/api/providers/`)

- Providers either:
  - Extend `RouterProvider` and use `fetchModel()` (e.g., Glama), or
  - Implement their own `fetchModel()` pattern (e.g., OpenRouter)
- In both cases, they use `getModels()` from `modelCache.ts` to fetch and cache models (see the sketch below)
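
For the `RouterProvider` route, the handler ends up roughly like the outline below. The import path, constructor option names, and base-class fields are assumptions, not the exact `RouterProvider` API:

```typescript
import { RouterProvider } from "./router-provider"

// Rough outline only — option names and base-class details are assumed.
export class ExampleHandler extends RouterProvider {
	constructor(options: { apiKey?: string; modelId?: string }) {
		super({
			name: "example", // must match the RouterName used by modelCache.ts
			modelId: options.modelId,
			defaultModelId: "example/default-model",
			// ...remaining base-class options (API key, default ModelInfo, etc.)
		})
	}

	// createMessage()/getModel() can then rely on fetchModel(), which loads the
	// cached model list via getModels() before resolving the selected model.
}
```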

### 3. Model Cache Integration

- `RouterName` type includes all providers that use the cache
- `modelCache.ts` has a switch statement that calls the appropriate fetcher
- Provides memory and file caching for model lists
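
Schematically, `getModels()` checks the in-memory cache first and only then calls the provider-specific fetcher; the sketch below simplifies the flow (the on-disk cache step is elided and `getModelsSketch` is a stand-in, not the real function):

```typescript
import NodeCache from "node-cache"
import type { ModelInfo } from "@roo-code/types"

type ModelRecord = Record<string, ModelInfo>

// Same 5-minute TTL as the memory cache in modelCache.ts.
const memoryCache = new NodeCache({ stdTTL: 5 * 60 })

async function getModelsSketch(provider: string, fetcher: () => Promise<ModelRecord>): Promise<ModelRecord> {
	const cached = memoryCache.get<ModelRecord>(provider)
	if (cached) {
		return cached
	}

	const models = await fetcher() // e.g. getHuggingFaceModels()
	memoryCache.set(provider, models)
	// ...also written to a file cache so stale data can be served if a later fetch fails
	return models
}
```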

## Implementation Steps

### Step 1: Create new fetcher

- Move `src/services/huggingface-models.ts` to `src/api/providers/fetchers/huggingface.ts`
- Transform the fetcher to return `Record<string, ModelInfo>` instead of raw HuggingFace models
- Parse HuggingFace model data to extract:
- `maxTokens`
- `contextWindow`
- `supportsImages` (based on pipeline_tag)
- `description`
- Other relevant `ModelInfo` fields
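
A condensed version of that mapping looks like this (`RawHFModel` and `toModelInfo` are local stand-ins; the defaults and the 8k output cap mirror `parseHuggingFaceModel` in the diff further down):

```typescript
import type { ModelInfo } from "@roo-code/types"

// Minimal slice of the raw HuggingFace model shape used by the mapping.
type RawHFModel = {
	id: string
	pipeline_tag?: string
	config: { model_type?: string; tokenizer_config?: { model_max_length?: number } }
}

function toModelInfo(model: RawHFModel): ModelInfo {
	const contextWindow = model.config.tokenizer_config?.model_max_length || 32768

	return {
		maxTokens: Math.min(contextWindow, 8192),
		contextWindow,
		supportsImages: model.pipeline_tag === "image-text-to-text",
		supportsPromptCache: false,
		description: model.config.model_type ? `Type: ${model.config.model_type}` : undefined,
		inputPrice: 0,
		outputPrice: 0,
	}
}
```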

### Step 2: Update RouterName and modelCache

- Add `"huggingface"` to the `RouterName` type in `src/shared/api.ts`
- Add HuggingFace case to the switch statement in `modelCache.ts`
- Update `GetModelsOptions` type to include HuggingFace
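
In outline, the type changes look like this (the existing members of `RouterName` and `GetModelsOptions` in `src/shared/api.ts` are approximated and partly elided; the `modelCache.ts` switch case itself appears in the diff below):

```typescript
// src/shared/api.ts (approximated): add "huggingface" to the router union.
export type RouterName = "openrouter" | "glama" | "ollama" | "lmstudio" | "litellm" | "huggingface"

// Discriminated options consumed by getModels(); other providers elided.
export type GetModelsOptions =
	| { provider: "openrouter" }
	| { provider: "glama" }
	| { provider: "ollama"; baseUrl?: string }
	| { provider: "lmstudio"; baseUrl?: string }
	| { provider: "huggingface" }
```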

### Step 3: Update HuggingFace provider

- Either extend `RouterProvider` or implement `fetchModel()` pattern
- Use `getModels()` from modelCache to fetch models
- Remove hardcoded model info from `getModel()`
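
After the change, model resolution reads from the cached record rather than hardcoded values — roughly like this sketch (the default ID and fallback info are placeholders, and the helper name is illustrative):

```typescript
import type { ModelInfo } from "@roo-code/types"

// Hypothetical resolution step: pick the configured model out of the cached
// record populated by getModels(), falling back to a default.
function resolveModel(
	models: Record<string, ModelInfo>,
	configuredId: string | undefined,
	defaultId: string,
	defaultInfo: ModelInfo,
): { id: string; info: ModelInfo } {
	const id = configuredId && models[configuredId] ? configuredId : defaultId
	return { id, info: models[id] ?? defaultInfo }
}
```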

### Step 4: Update webview integration

- Modify `webviewMessageHandler.ts` to use the new pattern
- Instead of importing from `src/api/huggingface-models.ts`, use `getModels()` with provider "huggingface"
- Transform the response to match the expected format for the webview
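
In the message handler, the change amounts to something like the following (the import path, message type, and field names are illustrative, not the real webview protocol):

```typescript
import { getModels } from "../api/providers/fetchers/modelCache"

// Sketch of the webviewMessageHandler.ts change: fetch via the shared cache and
// reshape the record into the list the webview already expects.
async function handleRequestHuggingFaceModels(postMessage: (msg: unknown) => void) {
	const models = await getModels({ provider: "huggingface" })

	postMessage({
		type: "huggingFaceModels",
		huggingFaceModels: Object.entries(models).map(([id, info]) => ({ id, ...info })),
	})
}
```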

### Step 5: Cleanup

- Remove `src/api/huggingface-models.ts`
- Remove the old `src/services/huggingface-models.ts`
- Update any other imports

## Benefits of this refactoring

1. **Consistency**: HuggingFace will follow the same pattern as other providers
2. **Caching**: Model lists will be cached in memory and on disk
3. **Maintainability**: Easier to understand and modify when all providers follow the same pattern
4. **Type safety**: Better integration with TypeScript types
26 changes: 26 additions & 0 deletions pr-body.md
@@ -0,0 +1,26 @@
## Summary

This PR refactors the HuggingFace provider implementation to match the established pattern used by other providers that fetch models via network calls (e.g., OpenRouter, Glama, Ollama).

## Changes

- **Moved fetcher to correct location**: Moved `huggingface-models.ts` from `src/services/` to `src/api/providers/fetchers/huggingface.ts`
- **Updated fetcher to return ModelInfo**: The fetcher now returns `Record<string, ModelInfo>` instead of raw HuggingFace model data, consistent with other providers
- **Integrated with model cache**: Added HuggingFace to `RouterName` type and integrated it with the `modelCache.ts` system for memory and file caching
- **Updated provider to extend RouterProvider**: The HuggingFace provider now extends the `RouterProvider` base class and uses the `fetchModel()` pattern
- **Removed unnecessary wrapper**: Deleted `src/api/huggingface-models.ts` as it's no longer needed
- **Updated webview integration**: Modified `webviewMessageHandler.ts` to use the new pattern with `getModels()` while maintaining backward compatibility

## Benefits

1. **Consistency**: HuggingFace now follows the same pattern as other providers
2. **Caching**: Model lists are now cached in memory and on disk
3. **Maintainability**: Easier to understand and modify when all providers follow the same pattern
4. **Type safety**: Better integration with TypeScript types

## Testing

- ✅ All existing tests pass
- ✅ TypeScript compilation successful
- ✅ Linting checks pass
- ✅ Added HuggingFace to RouterModels mock in webview tests
17 changes: 0 additions & 17 deletions src/api/huggingface-models.ts

This file was deleted.

src/services/huggingface-models.ts → src/api/providers/fetchers/huggingface.ts
@@ -1,3 +1,7 @@
import axios from "axios"
import { ModelInfo } from "@roo-code/types"
import { z } from "zod"

export interface HuggingFaceModel {
_id: string
id: string
@@ -52,9 +56,8 @@ const BASE_URL = "https://huggingface.co/api/models"
const CACHE_DURATION = 1000 * 60 * 60 // 1 hour

interface CacheEntry {
data: HuggingFaceModel[]
data: Record<string, ModelInfo>
timestamp: number
status: "success" | "partial" | "error"
}

let cache: CacheEntry | null = null
@@ -95,7 +98,46 @@ const requestInit: RequestInit = {
mode: "cors",
}

export async function fetchHuggingFaceModels(): Promise<HuggingFaceModel[]> {
/**
* Parse a HuggingFace model into ModelInfo format
*/
function parseHuggingFaceModel(model: HuggingFaceModel): ModelInfo {
// Extract context window from tokenizer config if available
const contextWindow = model.config.tokenizer_config?.model_max_length || 32768 // Default to 32k

// Determine if model supports images based on pipeline tag
const supportsImages = model.pipeline_tag === "image-text-to-text"

// Create a description from available metadata
const description = [
model.config.model_type ? `Type: ${model.config.model_type}` : null,
model.config.architectures?.length ? `Architecture: ${model.config.architectures[0]}` : null,
model.library_name ? `Library: ${model.library_name}` : null,
model.inferenceProviderMapping?.length
? `Providers: ${model.inferenceProviderMapping.map((p) => p.provider).join(", ")}`
: null,
]
.filter(Boolean)
.join(", ")

const modelInfo: ModelInfo = {
maxTokens: Math.min(contextWindow, 8192), // Conservative default, most models support at least 8k output
contextWindow,
supportsImages,
supportsPromptCache: false, // HuggingFace inference API doesn't support prompt caching
description,
// HuggingFace models through their inference API are generally free
inputPrice: 0,
outputPrice: 0,
}

return modelInfo
}

/**
* Fetch HuggingFace models and return them in ModelInfo format
*/
export async function getHuggingFaceModels(): Promise<Record<string, ModelInfo>> {
const now = Date.now()

// Check cache
@@ -104,6 +146,8 @@ export async function fetchHuggingFaceModels(): Promise<HuggingFaceModel[]> {
return cache.data
}

const models: Record<string, ModelInfo> = {}

try {
console.log("Fetching Hugging Face models from API...")

@@ -115,57 +159,49 @@ export async function fetchHuggingFaceModels(): Promise<HuggingFaceModel[]> {

let textGenModels: HuggingFaceModel[] = []
let imgTextModels: HuggingFaceModel[] = []
let hasErrors = false

// Process text-generation models
if (textGenResponse.status === "fulfilled" && textGenResponse.value.ok) {
textGenModels = await textGenResponse.value.json()
} else {
console.error("Failed to fetch text-generation models:", textGenResponse)
hasErrors = true
}

// Process image-text-to-text models
if (imgTextResponse.status === "fulfilled" && imgTextResponse.value.ok) {
imgTextModels = await imgTextResponse.value.json()
} else {
console.error("Failed to fetch image-text-to-text models:", imgTextResponse)
hasErrors = true
}

// Combine and filter models
const allModels = [...textGenModels, ...imgTextModels]
.filter((model) => model.inferenceProviderMapping.length > 0)
.sort((a, b) => a.id.toLowerCase().localeCompare(b.id.toLowerCase()))
const allModels = [...textGenModels, ...imgTextModels].filter(
(model) => model.inferenceProviderMapping.length > 0,
)

// Convert to ModelInfo format
for (const model of allModels) {
models[model.id] = parseHuggingFaceModel(model)
}

// Update cache
cache = {
data: allModels,
data: models,
timestamp: now,
status: hasErrors ? "partial" : "success",
}

console.log(`Fetched ${allModels.length} Hugging Face models (status: ${cache.status})`)
return allModels
console.log(`Fetched ${Object.keys(models).length} Hugging Face models`)
return models
} catch (error) {
console.error("Error fetching Hugging Face models:", error)

// Return cached data if available
if (cache) {
console.log("Using stale cached data due to fetch error")
cache.status = "error"
return cache.data
}

// No cache available, return empty array
return []
// No cache available, return empty object
return {}
}
}

export function getCachedModels(): HuggingFaceModel[] | null {
return cache?.data || null
}

export function clearCache(): void {
cache = null
}
4 changes: 4 additions & 0 deletions src/api/providers/fetchers/modelCache.ts
@@ -17,6 +17,7 @@ import { getLiteLLMModels } from "./litellm"
import { GetModelsOptions } from "../../../shared/api"
import { getOllamaModels } from "./ollama"
import { getLMStudioModels } from "./lmstudio"
import { getHuggingFaceModels } from "./huggingface"

const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })

@@ -78,6 +79,9 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
case "lmstudio":
models = await getLMStudioModels(options.baseUrl)
break
case "huggingface":
models = await getHuggingFaceModels()
break
default: {
// Ensures router is exhaustively checked if RouterName is a strict union
const exhaustiveCheck: never = provider