Commit e0e5416
refactor: align HuggingFace provider with established pattern
- Move huggingface-models.ts from src/services/ to src/api/providers/fetchers/huggingface.ts
- Update fetcher to return ModelInfo records instead of raw HuggingFace models
- Add HuggingFace to RouterName type and integrate with modelCache.ts
- Update HuggingFace provider to extend RouterProvider base class
- Remove unnecessary src/api/huggingface-models.ts wrapper
- Update webviewMessageHandler to use the new pattern with getModels()
- Maintain backward compatibility with webview by transforming ModelInfo to expected format
1 parent d62a260 commit e0e5416

File tree

7 files changed: +239 −77 lines

huggingface-refactor-plan.md

Lines changed: 83 additions & 0 deletions
# HuggingFace Provider Refactoring Plan

## Overview

The HuggingFace provider implementation needs to be refactored to match the established pattern used by other providers that fetch models via network calls (e.g., OpenRouter, Glama, or Ollama).

## Current Implementation Issues

1. **File locations are incorrect:**

   - `src/services/huggingface-models.ts` - should live in `src/api/providers/fetchers/`
   - `src/api/huggingface-models.ts` - an unnecessary wrapper that should be removed

2. **Pattern mismatch:**

   - The current implementation returns raw HuggingFace model data
   - It should return `ModelInfo` records like other providers
   - It is not integrated with the `modelCache.ts` system
   - The provider doesn't use the `RouterProvider` base class or the `fetchModel` pattern

## Established Pattern (from other providers)

### 1. Fetcher Pattern (`src/api/providers/fetchers/`)

- Fetcher files export a function like `getHuggingFaceModels()` that returns `Record<string, ModelInfo>` (a sketch follows this list)
- Fetchers handle API calls and transform raw data to `ModelInfo` format
- Examples: `getOpenRouterModels()`, `getGlamaModels()`, `getOllamaModels()`
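As a rough sketch of that shape (the endpoint URL, the `RawModel` interface, and the default field values below are illustrative assumptions, not the actual fetcher code):

```typescript
import axios from "axios"
import { ModelInfo } from "@roo-code/types"

// Hypothetical raw shape returned by a provider's model-list endpoint.
interface RawModel {
	id: string
	context_length?: number
}

export async function getExampleModels(): Promise<Record<string, ModelInfo>> {
	const models: Record<string, ModelInfo> = {}

	// Fetch the raw list, then map each entry into a ModelInfo record
	// keyed by model id, which is the shape modelCache.ts expects.
	const response = await axios.get<RawModel[]>("https://example.com/api/models")

	for (const raw of response.data) {
		models[raw.id] = {
			maxTokens: 8192, // conservative output cap
			contextWindow: raw.context_length ?? 32768,
			supportsImages: false,
			supportsPromptCache: false,
		}
	}

	return models
}
```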
### 2. Provider Pattern (`src/api/providers/`)

- Providers either:
  - extend `RouterProvider` and use `fetchModel()` (e.g., Glama), or
  - implement their own `fetchModel()` pattern (e.g., OpenRouter)
- Providers use `getModels()` from `modelCache.ts` to fetch and cache models

### 3. Model Cache Integration

- The `RouterName` type includes all providers that use the cache
- `modelCache.ts` has a switch statement that calls the appropriate fetcher
- The cache provides memory and file caching for model lists
## Implementation Steps

### Step 1: Create new fetcher

- Move `src/services/huggingface-models.ts` to `src/api/providers/fetchers/huggingface.ts`
- Transform the fetcher to return `Record<string, ModelInfo>` instead of raw HuggingFace models
- Parse HuggingFace model data to extract:
  - `maxTokens`
  - `contextWindow`
  - `supportsImages` (based on `pipeline_tag`)
  - `description`
  - Other relevant `ModelInfo` fields

### Step 2: Update RouterName and modelCache

- Add `"huggingface"` to the `RouterName` type in `src/shared/api.ts`
- Add a HuggingFace case to the switch statement in `modelCache.ts`
- Update the `GetModelsOptions` type to include HuggingFace (sketched below)
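The `src/shared/api.ts` diff is not included in this view, so the sketch below is an assumption about its shape: only the `"huggingface"` additions are what Step 2 prescribes, while the other union members and option fields are illustrative.

```typescript
// src/shared/api.ts (sketch; other members are illustrative)
export type RouterName = "openrouter" | "glama" | "ollama" | "lmstudio" | "litellm" | "huggingface"

// HuggingFace needs no extra options such as baseUrl, since its fetcher
// takes no arguments, so a bare provider tag is assumed to suffice.
export type GetModelsOptions =
	| { provider: "openrouter" }
	| { provider: "glama" }
	| { provider: "ollama"; baseUrl?: string }
	| { provider: "lmstudio"; baseUrl?: string }
	| { provider: "litellm"; baseUrl?: string }
	| { provider: "huggingface" }
```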
### Step 3: Update HuggingFace provider

- Either extend `RouterProvider` or implement the `fetchModel()` pattern
- Use `getModels()` from `modelCache` to fetch models
- Remove the hardcoded model info from `getModel()`

### Step 4: Update webview integration

- Modify `webviewMessageHandler.ts` to use the new pattern
- Instead of importing from `src/api/huggingface-models.ts`, use `getModels()` with provider `"huggingface"`
- Transform the response to match the expected format for the webview (see the sketch after this list)
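The `webviewMessageHandler.ts` diff is not shown in this section, so the following is only a sketch of the intended wiring; the helper name, message `type` string, and transformed fields are hypothetical:

```typescript
// webviewMessageHandler.ts (sketch)
import { getModels } from "../api/providers/fetchers/modelCache"

async function handleRequestHuggingFaceModels(postMessage: (msg: unknown) => void) {
	// Fetch ModelInfo records through the shared cache.
	const models = await getModels({ provider: "huggingface" })

	// Transform back to the array shape the webview previously received,
	// preserving backward compatibility.
	const huggingFaceModels = Object.entries(models).map(([id, info]) => ({
		id,
		contextWindow: info.contextWindow,
		supportsImages: info.supportsImages ?? false,
		description: info.description,
	}))

	postMessage({ type: "huggingFaceModels", huggingFaceModels })
}
```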
### Step 5: Cleanup

- Remove `src/api/huggingface-models.ts`
- Remove the old `src/services/huggingface-models.ts`
- Update any other imports

## Benefits of this refactoring

1. **Consistency**: HuggingFace will follow the same pattern as other providers
2. **Caching**: Model lists will be cached in memory and on disk
3. **Maintainability**: Easier to understand and modify when all providers follow the same pattern
4. **Type safety**: Better integration with TypeScript types

src/api/huggingface-models.ts

Lines changed: 0 additions & 17 deletions
This file was deleted.
src/services/huggingface-models.ts → src/api/providers/fetchers/huggingface.ts

Lines changed: 60 additions & 24 deletions
@@ -1,3 +1,7 @@
+import axios from "axios"
+import { ModelInfo } from "@roo-code/types"
+import { z } from "zod"
+
 export interface HuggingFaceModel {
 	_id: string
 	id: string
@@ -52,9 +56,8 @@ const BASE_URL = "https://huggingface.co/api/models"
 const CACHE_DURATION = 1000 * 60 * 60 // 1 hour

 interface CacheEntry {
-	data: HuggingFaceModel[]
+	data: Record<string, ModelInfo>
 	timestamp: number
-	status: "success" | "partial" | "error"
 }

 let cache: CacheEntry | null = null
@@ -95,7 +98,46 @@ const requestInit: RequestInit = {
 	mode: "cors",
 }

-export async function fetchHuggingFaceModels(): Promise<HuggingFaceModel[]> {
+/**
+ * Parse a HuggingFace model into ModelInfo format
+ */
+function parseHuggingFaceModel(model: HuggingFaceModel): ModelInfo {
+	// Extract context window from tokenizer config if available
+	const contextWindow = model.config.tokenizer_config?.model_max_length || 32768 // Default to 32k
+
+	// Determine if model supports images based on pipeline tag
+	const supportsImages = model.pipeline_tag === "image-text-to-text"
+
+	// Create a description from available metadata
+	const description = [
+		model.config.model_type ? `Type: ${model.config.model_type}` : null,
+		model.config.architectures?.length ? `Architecture: ${model.config.architectures[0]}` : null,
+		model.library_name ? `Library: ${model.library_name}` : null,
+		model.inferenceProviderMapping?.length
+			? `Providers: ${model.inferenceProviderMapping.map((p) => p.provider).join(", ")}`
+			: null,
+	]
+		.filter(Boolean)
+		.join(", ")
+
+	const modelInfo: ModelInfo = {
+		maxTokens: Math.min(contextWindow, 8192), // Conservative default, most models support at least 8k output
+		contextWindow,
+		supportsImages,
+		supportsPromptCache: false, // HuggingFace inference API doesn't support prompt caching
+		description,
+		// HuggingFace models through their inference API are generally free
+		inputPrice: 0,
+		outputPrice: 0,
+	}
+
+	return modelInfo
+}
+
+/**
+ * Fetch HuggingFace models and return them in ModelInfo format
+ */
+export async function getHuggingFaceModels(): Promise<Record<string, ModelInfo>> {
 	const now = Date.now()

 	// Check cache
@@ -104,6 +146,8 @@ export async function fetchHuggingFaceModels(): Promise<HuggingFaceModel[]> {
 		return cache.data
 	}

+	const models: Record<string, ModelInfo> = {}
+
 	try {
 		console.log("Fetching Hugging Face models from API...")

@@ -115,57 +159,49 @@ export async function fetchHuggingFaceModels(): Promise<HuggingFaceModel[]> {

 		let textGenModels: HuggingFaceModel[] = []
 		let imgTextModels: HuggingFaceModel[] = []
-		let hasErrors = false

 		// Process text-generation models
 		if (textGenResponse.status === "fulfilled" && textGenResponse.value.ok) {
 			textGenModels = await textGenResponse.value.json()
 		} else {
 			console.error("Failed to fetch text-generation models:", textGenResponse)
-			hasErrors = true
 		}

 		// Process image-text-to-text models
 		if (imgTextResponse.status === "fulfilled" && imgTextResponse.value.ok) {
 			imgTextModels = await imgTextResponse.value.json()
 		} else {
 			console.error("Failed to fetch image-text-to-text models:", imgTextResponse)
-			hasErrors = true
 		}

 		// Combine and filter models
-		const allModels = [...textGenModels, ...imgTextModels]
-			.filter((model) => model.inferenceProviderMapping.length > 0)
-			.sort((a, b) => a.id.toLowerCase().localeCompare(b.id.toLowerCase()))
+		const allModels = [...textGenModels, ...imgTextModels].filter(
+			(model) => model.inferenceProviderMapping.length > 0,
+		)
+
+		// Convert to ModelInfo format
+		for (const model of allModels) {
+			models[model.id] = parseHuggingFaceModel(model)
+		}

 		// Update cache
 		cache = {
-			data: allModels,
+			data: models,
 			timestamp: now,
-			status: hasErrors ? "partial" : "success",
 		}

-		console.log(`Fetched ${allModels.length} Hugging Face models (status: ${cache.status})`)
-		return allModels
+		console.log(`Fetched ${Object.keys(models).length} Hugging Face models`)
+		return models
 	} catch (error) {
 		console.error("Error fetching Hugging Face models:", error)

 		// Return cached data if available
 		if (cache) {
 			console.log("Using stale cached data due to fetch error")
-			cache.status = "error"
 			return cache.data
 		}

-		// No cache available, return empty array
-		return []
+		// No cache available, return empty object
+		return {}
 	}
 }
-
-export function getCachedModels(): HuggingFaceModel[] | null {
-	return cache?.data || null
-}
-
-export function clearCache(): void {
-	cache = null
-}

src/api/providers/fetchers/modelCache.ts

Lines changed: 4 additions & 0 deletions
@@ -17,6 +17,7 @@ import { getLiteLLMModels } from "./litellm"
 import { GetModelsOptions } from "../../../shared/api"
 import { getOllamaModels } from "./ollama"
 import { getLMStudioModels } from "./lmstudio"
+import { getHuggingFaceModels } from "./huggingface"

 const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })

@@ -78,6 +79,9 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
 		case "lmstudio":
 			models = await getLMStudioModels(options.baseUrl)
 			break
+		case "huggingface":
+			models = await getHuggingFaceModels()
+			break
 		default: {
 			// Ensures router is exhaustively checked if RouterName is a strict union
 			const exhaustiveCheck: never = provider
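With this case wired in, callers no longer need a HuggingFace-specific module. A minimal usage sketch, assuming `{ provider: "huggingface" }` is a valid `GetModelsOptions` member per Step 2 of the plan:

```typescript
import { getModels } from "./modelCache"

async function main() {
	// Repeat calls within the TTL are served from the NodeCache memory
	// cache; otherwise getModels() falls through to the fetcher.
	const models = await getModels({ provider: "huggingface" })
	console.log(`${Object.keys(models).length} HuggingFace models available`)
}

main()
```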

src/api/providers/huggingface.ts

Lines changed: 38 additions & 31 deletions
@@ -1,38 +1,46 @@
 import OpenAI from "openai"
 import { Anthropic } from "@anthropic-ai/sdk"

-import type { ApiHandlerOptions } from "../../shared/api"
+import { type ModelInfo } from "@roo-code/types"
+
+import type { ApiHandlerOptions, ModelRecord } from "../../shared/api"
 import { ApiStream } from "../transform/stream"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
 import { DEFAULT_HEADERS } from "./constants"
-import { BaseProvider } from "./base-provider"
-
-export class HuggingFaceHandler extends BaseProvider implements SingleCompletionHandler {
-	private client: OpenAI
-	private options: ApiHandlerOptions
+import { RouterProvider } from "./router-provider"
+
+// Default model info for fallback
+const huggingFaceDefaultModelInfo: ModelInfo = {
+	maxTokens: 8192,
+	contextWindow: 131072,
+	supportsImages: false,
+	supportsPromptCache: false,
+}

+export class HuggingFaceHandler extends RouterProvider implements SingleCompletionHandler {
 	constructor(options: ApiHandlerOptions) {
-		super()
-		this.options = options
+		super({
+			options,
+			name: "huggingface",
+			baseURL: "https://router.huggingface.co/v1",
+			apiKey: options.huggingFaceApiKey,
+			modelId: options.huggingFaceModelId,
+			defaultModelId: "meta-llama/Llama-3.3-70B-Instruct",
+			defaultModelInfo: huggingFaceDefaultModelInfo,
+		})

 		if (!this.options.huggingFaceApiKey) {
 			throw new Error("Hugging Face API key is required")
 		}
-
-		this.client = new OpenAI({
-			baseURL: "https://router.huggingface.co/v1",
-			apiKey: this.options.huggingFaceApiKey,
-			defaultHeaders: DEFAULT_HEADERS,
-		})
 	}

 	override async *createMessage(
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
-		const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
+		const { id: modelId, info } = await this.fetchModel()
 		const temperature = this.options.modelTemperature ?? 0.7

 		const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
@@ -43,6 +51,11 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion
 			stream_options: { include_usage: true },
 		}

+		// Add max_tokens if the model info specifies it
+		if (info.maxTokens && info.maxTokens > 0) {
+			params.max_tokens = info.maxTokens
+		}
+
 		const stream = await this.client.chat.completions.create(params)

 		for await (const chunk of stream) {
@@ -66,13 +79,20 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion
 	}

 	async completePrompt(prompt: string): Promise<string> {
-		const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
+		const { id: modelId, info } = await this.fetchModel()

 		try {
-			const response = await this.client.chat.completions.create({
+			const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
 				model: modelId,
 				messages: [{ role: "user", content: prompt }],
-			})
+			}
+
+			// Add max_tokens if the model info specifies it
+			if (info.maxTokens && info.maxTokens > 0) {
+				params.max_tokens = info.maxTokens
+			}
+
+			const response = await this.client.chat.completions.create(params)

 			return response.choices[0]?.message.content || ""
 		} catch (error) {
@@ -83,17 +103,4 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion
 			throw error
 		}
 	}
-
-	override getModel() {
-		const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
-		return {
-			id: modelId,
-			info: {
-				maxTokens: 8192,
-				contextWindow: 131072,
-				supportsImages: false,
-				supportsPromptCache: false,
-			},
-		}
-	}
 }
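For reference, a minimal sketch of how the refactored handler might be exercised; the options object is illustrative (only `huggingFaceApiKey` is required by the constructor above, and `huggingFaceModelId` falls back to the default model):

```typescript
import { HuggingFaceHandler } from "./src/api/providers/huggingface"

async function demo() {
	const handler = new HuggingFaceHandler({
		huggingFaceApiKey: process.env.HF_API_KEY!,
		huggingFaceModelId: "meta-llama/Llama-3.3-70B-Instruct",
	})

	// completePrompt() now resolves the model id and ModelInfo via
	// fetchModel() and the shared model cache before issuing the request.
	const text = await handler.completePrompt("Say hello")
	console.log(text)
}

demo()
```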
