
Commit bcad858

mrubens and TGlide authored
basic hugging face provider (#6134)
* basic hugging face provider
* fetch hf models and providers
* save provider to config
* Update translations

---------

Co-authored-by: Thomas G. Lopes <[email protected]>
1 parent 0323256 commit bcad858

33 files changed: +970 -0 lines changed

packages/types/src/provider-settings.ts

Lines changed: 10 additions & 0 deletions
@@ -32,6 +32,7 @@ export const providerNames = [
 	"groq",
 	"chutes",
 	"litellm",
+	"huggingface",
 ] as const
 
 export const providerNamesSchema = z.enum(providerNames)
@@ -219,6 +220,12 @@ const groqSchema = apiModelIdProviderModelSchema.extend({
 	groqApiKey: z.string().optional(),
 })
 
+const huggingFaceSchema = baseProviderSettingsSchema.extend({
+	huggingFaceApiKey: z.string().optional(),
+	huggingFaceModelId: z.string().optional(),
+	huggingFaceInferenceProvider: z.string().optional(),
+})
+
 const chutesSchema = apiModelIdProviderModelSchema.extend({
 	chutesApiKey: z.string().optional(),
 })
@@ -256,6 +263,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
 	fakeAiSchema.merge(z.object({ apiProvider: z.literal("fake-ai") })),
 	xaiSchema.merge(z.object({ apiProvider: z.literal("xai") })),
 	groqSchema.merge(z.object({ apiProvider: z.literal("groq") })),
+	huggingFaceSchema.merge(z.object({ apiProvider: z.literal("huggingface") })),
 	chutesSchema.merge(z.object({ apiProvider: z.literal("chutes") })),
 	litellmSchema.merge(z.object({ apiProvider: z.literal("litellm") })),
 	defaultSchema,
@@ -285,6 +293,7 @@ export const providerSettingsSchema = z.object({
 	...fakeAiSchema.shape,
 	...xaiSchema.shape,
 	...groqSchema.shape,
+	...huggingFaceSchema.shape,
 	...chutesSchema.shape,
 	...litellmSchema.shape,
 	...codebaseIndexProviderSchema.shape,
@@ -304,6 +313,7 @@ export const MODEL_ID_KEYS: Partial<keyof ProviderSettings>[] = [
 	"unboundModelId",
 	"requestyModelId",
 	"litellmModelId",
+	"huggingFaceModelId",
 ]
 
 export const getModelId = (settings: ProviderSettings): string | undefined => {
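For context, a minimal sketch of how the new schema entries might be exercised outside this diff. The field names and the "huggingface" literal come from the hunks above; the package entry point and the re-export of `getModelId` are assumptions.

```ts
// Hypothetical usage sketch, not part of this commit.
// Assumes providerSettingsSchemaDiscriminated and getModelId are re-exported
// from a package entry point such as "@roo-code/types" (path assumed).
import { providerSettingsSchemaDiscriminated, getModelId } from "@roo-code/types"

const settings = providerSettingsSchemaDiscriminated.parse({
	apiProvider: "huggingface",
	huggingFaceApiKey: "hf_xxx", // placeholder key
	huggingFaceModelId: "meta-llama/Llama-3.3-70B-Instruct",
	huggingFaceInferenceProvider: "together", // hypothetical routing value
})

// Since "huggingFaceModelId" is now listed in MODEL_ID_KEYS, getModelId
// should resolve the selected model for this provider.
console.log(getModelId(settings))
```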

src/api/huggingface-models.ts

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+import { fetchHuggingFaceModels, type HuggingFaceModel } from "../services/huggingface-models"
+
+export interface HuggingFaceModelsResponse {
+	models: HuggingFaceModel[]
+	cached: boolean
+	timestamp: number
+}
+
+export async function getHuggingFaceModels(): Promise<HuggingFaceModelsResponse> {
+	const models = await fetchHuggingFaceModels()
+
+	return {
+		models,
+		cached: false, // We could enhance this to track if data came from cache
+		timestamp: Date.now(),
+	}
+}
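A small consumer sketch for the wrapper above; the relative import path is assumed, and the envelope fields come straight from `HuggingFaceModelsResponse`.

```ts
// Hypothetical consumer, not part of this commit.
import { getHuggingFaceModels } from "./huggingface-models"

async function logModelCount(): Promise<void> {
	const { models, cached, timestamp } = await getHuggingFaceModels()
	console.log(`${models.length} models (cached: ${cached}) at ${new Date(timestamp).toISOString()}`)
}
```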

src/api/index.ts

Lines changed: 3 additions & 0 deletions
@@ -26,6 +26,7 @@ import {
 	FakeAIHandler,
 	XAIHandler,
 	GroqHandler,
+	HuggingFaceHandler,
 	ChutesHandler,
 	LiteLLMHandler,
 	ClaudeCodeHandler,
@@ -108,6 +109,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 			return new XAIHandler(options)
 		case "groq":
 			return new GroqHandler(options)
+		case "huggingface":
+			return new HuggingFaceHandler(options)
 		case "chutes":
 			return new ChutesHandler(options)
 		case "litellm":

src/api/providers/huggingface.ts

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+import OpenAI from "openai"
+import { Anthropic } from "@anthropic-ai/sdk"
+
+import type { ApiHandlerOptions } from "../../shared/api"
+import { ApiStream } from "../transform/stream"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+import { DEFAULT_HEADERS } from "./constants"
+import { BaseProvider } from "./base-provider"
+
+export class HuggingFaceHandler extends BaseProvider implements SingleCompletionHandler {
+	private client: OpenAI
+	private options: ApiHandlerOptions
+
+	constructor(options: ApiHandlerOptions) {
+		super()
+		this.options = options
+
+		if (!this.options.huggingFaceApiKey) {
+			throw new Error("Hugging Face API key is required")
+		}
+
+		this.client = new OpenAI({
+			baseURL: "https://router.huggingface.co/v1",
+			apiKey: this.options.huggingFaceApiKey,
+			defaultHeaders: DEFAULT_HEADERS,
+		})
+	}
+
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
+		const temperature = this.options.modelTemperature ?? 0.7
+
+		const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
+			model: modelId,
+			temperature,
+			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+			stream: true,
+			stream_options: { include_usage: true },
+		}
+
+		const stream = await this.client.chat.completions.create(params)
+
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+
+			if (delta?.content) {
+				yield {
+					type: "text",
+					text: delta.content,
+				}
+			}
+
+			if (chunk.usage) {
+				yield {
+					type: "usage",
+					inputTokens: chunk.usage.prompt_tokens || 0,
+					outputTokens: chunk.usage.completion_tokens || 0,
+				}
+			}
+		}
+	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
+
+		try {
+			const response = await this.client.chat.completions.create({
+				model: modelId,
+				messages: [{ role: "user", content: prompt }],
+			})
+
+			return response.choices[0]?.message.content || ""
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`Hugging Face completion error: ${error.message}`)
+			}
+
+			throw error
+		}
+	}
+
+	override getModel() {
+		const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
+		return {
+			id: modelId,
+			info: {
+				maxTokens: 8192,
+				contextWindow: 131072,
+				supportsImages: false,
+				supportsPromptCache: false,
+			},
+		}
+	}
+}
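To show the streaming contract, a consumption sketch assuming `ApiHandlerOptions` accepts this minimal shape; the "text" and "usage" chunk types mirror what `createMessage` yields above.

```ts
// Hypothetical streaming sketch, not part of this commit.
import { HuggingFaceHandler } from "./huggingface"

async function runOnce(apiKey: string): Promise<void> {
	const handler = new HuggingFaceHandler({
		huggingFaceApiKey: apiKey,
		huggingFaceModelId: "meta-llama/Llama-3.3-70B-Instruct", // same default the handler uses
	})

	const stream = handler.createMessage("You are a concise assistant.", [
		{ role: "user", content: "Say hello in five words." },
	])

	for await (const chunk of stream) {
		if (chunk.type === "text") {
			process.stdout.write(chunk.text)
		} else if (chunk.type === "usage") {
			console.log(`\n[tokens in=${chunk.inputTokens} out=${chunk.outputTokens}]`)
		}
	}
}
```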

src/api/providers/index.ts

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ export { FakeAIHandler } from "./fake-ai"
 export { GeminiHandler } from "./gemini"
 export { GlamaHandler } from "./glama"
 export { GroqHandler } from "./groq"
+export { HuggingFaceHandler } from "./huggingface"
 export { HumanRelayHandler } from "./human-relay"
 export { LiteLLMHandler } from "./lite-llm"
 export { LmStudioHandler } from "./lm-studio"

src/core/webview/webviewMessageHandler.ts

Lines changed: 16 additions & 0 deletions
@@ -674,6 +674,22 @@ export const webviewMessageHandler = async (
 			// TODO: Cache like we do for OpenRouter, etc?
 			provider.postMessageToWebview({ type: "vsCodeLmModels", vsCodeLmModels })
 			break
+		case "requestHuggingFaceModels":
+			try {
+				const { getHuggingFaceModels } = await import("../../api/huggingface-models")
+				const huggingFaceModelsResponse = await getHuggingFaceModels()
+				provider.postMessageToWebview({
+					type: "huggingFaceModels",
+					huggingFaceModels: huggingFaceModelsResponse.models,
+				})
+			} catch (error) {
+				console.error("Failed to fetch Hugging Face models:", error)
+				provider.postMessageToWebview({
+					type: "huggingFaceModels",
+					huggingFaceModels: [],
+				})
+			}
+			break
 		case "openImage":
 			openImage(message.text!, { values: message.values })
 			break
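The webview-side counterpart is not shown in this excerpt; here is a sketch of the round trip, assuming the standard `acquireVsCodeApi` bridge and the message type names used by the handler above. The type import path is an assumption.

```ts
// Hypothetical webview-side sketch, not part of this commit.
import type { HuggingFaceModel } from "../../src/services/huggingface-models" // path assumed

declare function acquireVsCodeApi(): { postMessage(message: unknown): void }
const vscode = acquireVsCodeApi()

window.addEventListener("message", (event: MessageEvent) => {
	const message = event.data
	if (message?.type === "huggingFaceModels") {
		const models: HuggingFaceModel[] = message.huggingFaceModels ?? []
		console.log(`Received ${models.length} Hugging Face models`)
	}
})

// Triggers the "requestHuggingFaceModels" case handled in the extension host.
vscode.postMessage({ type: "requestHuggingFaceModels" })
```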

src/services/huggingface-models.ts

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
+export interface HuggingFaceModel {
+	_id: string
+	id: string
+	inferenceProviderMapping: InferenceProviderMapping[]
+	trendingScore: number
+	config: ModelConfig
+	tags: string[]
+	pipeline_tag: "text-generation" | "image-text-to-text"
+	library_name?: string
+}
+
+export interface InferenceProviderMapping {
+	provider: string
+	providerId: string
+	status: "live" | "staging" | "error"
+	task: "conversational"
+}
+
+export interface ModelConfig {
+	architectures: string[]
+	model_type: string
+	tokenizer_config?: {
+		chat_template?: string | Array<{ name: string; template: string }>
+		model_max_length?: number
+	}
+}
+
+interface HuggingFaceApiParams {
+	pipeline_tag?: "text-generation" | "image-text-to-text"
+	filter: string
+	inference_provider: string
+	limit: number
+	expand: string[]
+}
+
+const DEFAULT_PARAMS: HuggingFaceApiParams = {
+	filter: "conversational",
+	inference_provider: "all",
+	limit: 100,
+	expand: [
+		"inferenceProviderMapping",
+		"config",
+		"library_name",
+		"pipeline_tag",
+		"tags",
+		"mask_token",
+		"trendingScore",
+	],
+}
+
+const BASE_URL = "https://huggingface.co/api/models"
+const CACHE_DURATION = 1000 * 60 * 60 // 1 hour
+
+interface CacheEntry {
+	data: HuggingFaceModel[]
+	timestamp: number
+	status: "success" | "partial" | "error"
+}
+
+let cache: CacheEntry | null = null
+
+function buildApiUrl(params: HuggingFaceApiParams): string {
+	const url = new URL(BASE_URL)
+
+	// Add simple params
+	Object.entries(params).forEach(([key, value]) => {
+		if (!Array.isArray(value)) {
+			url.searchParams.append(key, String(value))
+		}
+	})
+
+	// Handle array params specially
+	params.expand.forEach((item) => {
+		url.searchParams.append("expand[]", item)
+	})
+
+	return url.toString()
+}
+
+const headers: HeadersInit = {
+	"Upgrade-Insecure-Requests": "1",
+	"Sec-Fetch-Dest": "document",
+	"Sec-Fetch-Mode": "navigate",
+	"Sec-Fetch-Site": "none",
+	"Sec-Fetch-User": "?1",
+	Priority: "u=0, i",
+	Pragma: "no-cache",
+	"Cache-Control": "no-cache",
+}
+
+const requestInit: RequestInit = {
+	credentials: "include",
+	headers,
+	method: "GET",
+	mode: "cors",
+}
+
+export async function fetchHuggingFaceModels(): Promise<HuggingFaceModel[]> {
+	const now = Date.now()
+
+	// Check cache
+	if (cache && now - cache.timestamp < CACHE_DURATION) {
+		console.log("Using cached Hugging Face models")
+		return cache.data
+	}
+
+	try {
+		console.log("Fetching Hugging Face models from API...")
+
+		// Fetch both text-generation and image-text-to-text models in parallel
+		const [textGenResponse, imgTextResponse] = await Promise.allSettled([
+			fetch(buildApiUrl({ ...DEFAULT_PARAMS, pipeline_tag: "text-generation" }), requestInit),
+			fetch(buildApiUrl({ ...DEFAULT_PARAMS, pipeline_tag: "image-text-to-text" }), requestInit),
+		])
+
+		let textGenModels: HuggingFaceModel[] = []
+		let imgTextModels: HuggingFaceModel[] = []
+		let hasErrors = false
+
+		// Process text-generation models
+		if (textGenResponse.status === "fulfilled" && textGenResponse.value.ok) {
+			textGenModels = await textGenResponse.value.json()
+		} else {
+			console.error("Failed to fetch text-generation models:", textGenResponse)
+			hasErrors = true
+		}
+
+		// Process image-text-to-text models
+		if (imgTextResponse.status === "fulfilled" && imgTextResponse.value.ok) {
+			imgTextModels = await imgTextResponse.value.json()
+		} else {
+			console.error("Failed to fetch image-text-to-text models:", imgTextResponse)
+			hasErrors = true
+		}
+
+		// Combine and filter models
+		const allModels = [...textGenModels, ...imgTextModels]
+			.filter((model) => model.inferenceProviderMapping.length > 0)
+			.sort((a, b) => a.id.toLowerCase().localeCompare(b.id.toLowerCase()))
+
+		// Update cache
+		cache = {
+			data: allModels,
+			timestamp: now,
+			status: hasErrors ? "partial" : "success",
+		}
+
+		console.log(`Fetched ${allModels.length} Hugging Face models (status: ${cache.status})`)
+		return allModels
+	} catch (error) {
+		console.error("Error fetching Hugging Face models:", error)
+
+		// Return cached data if available
+		if (cache) {
+			console.log("Using stale cached data due to fetch error")
+			cache.status = "error"
+			return cache.data
+		}
+
+		// No cache available, return empty array
+		return []
+	}
+}
+
+export function getCachedModels(): HuggingFaceModel[] | null {
+	return cache?.data || null
+}
+
+export function clearCache(): void {
+	cache = null
+}
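A sketch of how the exported service and the `HuggingFaceModel` shape might be queried, for example to list the models a particular inference provider serves live. The relative import path is assumed and the provider name is a placeholder.

```ts
// Hypothetical helper, not part of this commit.
import { fetchHuggingFaceModels } from "./huggingface-models"

async function modelsForProvider(provider: string): Promise<string[]> {
	const models = await fetchHuggingFaceModels()
	return models
		.filter((m) => m.inferenceProviderMapping.some((p) => p.provider === provider && p.status === "live"))
		.map((m) => m.id)
}

// Example: modelsForProvider("together").then((ids) => console.log(ids))
```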
