Commit 921ede7

fetch hf models and providers

1 parent: be4fcfa
File tree: 6 files changed, +371 −8 lines

src/api/huggingface-models.ts (17 additions, 0 deletions)

@@ -0,0 +1,17 @@
+import { fetchHuggingFaceModels, type HuggingFaceModel } from "../services/huggingface-models"
+
+export interface HuggingFaceModelsResponse {
+	models: HuggingFaceModel[]
+	cached: boolean
+	timestamp: number
+}
+
+export async function getHuggingFaceModels(): Promise<HuggingFaceModelsResponse> {
+	const models = await fetchHuggingFaceModels()
+
+	return {
+		models,
+		cached: false, // We could enhance this to track if data came from cache
+		timestamp: Date.now(),
+	}
+}
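
Note: a minimal sketch of how a caller might consume this wrapper (the call site below is assumed for illustration, not part of the commit):

	import { getHuggingFaceModels } from "./api/huggingface-models"

	const response = await getHuggingFaceModels()
	console.log(`Got ${response.models.length} models at ${new Date(response.timestamp).toISOString()}`)
	// `cached` is always false for now, per the comment in the diff above.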

src/core/webview/webviewMessageHandler.ts (16 additions, 0 deletions)

@@ -674,6 +674,22 @@ export const webviewMessageHandler = async (
 			// TODO: Cache like we do for OpenRouter, etc?
 			provider.postMessageToWebview({ type: "vsCodeLmModels", vsCodeLmModels })
 			break
+		case "requestHuggingFaceModels":
+			try {
+				const { getHuggingFaceModels } = await import("../../api/huggingface-models")
+				const huggingFaceModelsResponse = await getHuggingFaceModels()
+				provider.postMessageToWebview({
+					type: "huggingFaceModels",
+					huggingFaceModels: huggingFaceModelsResponse.models,
+				})
+			} catch (error) {
+				console.error("Failed to fetch Hugging Face models:", error)
+				provider.postMessageToWebview({
+					type: "huggingFaceModels",
+					huggingFaceModels: [],
+				})
+			}
+			break
 		case "openImage":
 			openImage(message.text!, { values: message.values })
 			break
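
Note: the webview side of this round trip would look roughly like the sketch below (assumed usage built on the standard acquireVsCodeApi() bridge; not part of this commit). The message shapes mirror WebviewMessage and ExtensionMessage further down.

	const vscode = acquireVsCodeApi()

	// Ask the extension host for the model list...
	vscode.postMessage({ type: "requestHuggingFaceModels" })

	// ...and receive either the fetched models or [] on failure.
	window.addEventListener("message", (event) => {
		const message = event.data
		if (message.type === "huggingFaceModels") {
			console.log(`Received ${message.huggingFaceModels.length} Hugging Face models`)
		}
	})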

src/services/huggingface-models.ts (171 additions, 0 deletions)

@@ -0,0 +1,171 @@
+export interface HuggingFaceModel {
+	_id: string
+	id: string
+	inferenceProviderMapping: InferenceProviderMapping[]
+	trendingScore: number
+	config: ModelConfig
+	tags: string[]
+	pipeline_tag: "text-generation" | "image-text-to-text"
+	library_name?: string
+}
+
+export interface InferenceProviderMapping {
+	provider: string
+	providerId: string
+	status: "live" | "staging" | "error"
+	task: "conversational"
+}
+
+export interface ModelConfig {
+	architectures: string[]
+	model_type: string
+	tokenizer_config?: {
+		chat_template?: string | Array<{ name: string; template: string }>
+		model_max_length?: number
+	}
+}
+
+interface HuggingFaceApiParams {
+	pipeline_tag?: "text-generation" | "image-text-to-text"
+	filter: string
+	inference_provider: string
+	limit: number
+	expand: string[]
+}
+
+const DEFAULT_PARAMS: HuggingFaceApiParams = {
+	filter: "conversational",
+	inference_provider: "all",
+	limit: 100,
+	expand: [
+		"inferenceProviderMapping",
+		"config",
+		"library_name",
+		"pipeline_tag",
+		"tags",
+		"mask_token",
+		"trendingScore",
+	],
+}
+
+const BASE_URL = "https://huggingface.co/api/models"
+const CACHE_DURATION = 1000 * 60 * 60 // 1 hour
+
+interface CacheEntry {
+	data: HuggingFaceModel[]
+	timestamp: number
+	status: "success" | "partial" | "error"
+}
+
+let cache: CacheEntry | null = null
+
+function buildApiUrl(params: HuggingFaceApiParams): string {
+	const url = new URL(BASE_URL)
+
+	// Add simple params
+	Object.entries(params).forEach(([key, value]) => {
+		if (!Array.isArray(value)) {
+			url.searchParams.append(key, String(value))
+		}
+	})
+
+	// Handle array params specially
+	params.expand.forEach((item) => {
+		url.searchParams.append("expand[]", item)
+	})
+
+	return url.toString()
+}
+
+const headers: HeadersInit = {
+	"Upgrade-Insecure-Requests": "1",
+	"Sec-Fetch-Dest": "document",
+	"Sec-Fetch-Mode": "navigate",
+	"Sec-Fetch-Site": "none",
+	"Sec-Fetch-User": "?1",
+	Priority: "u=0, i",
+	Pragma: "no-cache",
+	"Cache-Control": "no-cache",
+}
+
+const requestInit: RequestInit = {
+	credentials: "include",
+	headers,
+	method: "GET",
+	mode: "cors",
+}
+
+export async function fetchHuggingFaceModels(): Promise<HuggingFaceModel[]> {
+	const now = Date.now()
+
+	// Check cache
+	if (cache && now - cache.timestamp < CACHE_DURATION) {
+		console.log("Using cached Hugging Face models")
+		return cache.data
+	}
+
+	try {
+		console.log("Fetching Hugging Face models from API...")
+
+		// Fetch both text-generation and image-text-to-text models in parallel
+		const [textGenResponse, imgTextResponse] = await Promise.allSettled([
+			fetch(buildApiUrl({ ...DEFAULT_PARAMS, pipeline_tag: "text-generation" }), requestInit),
+			fetch(buildApiUrl({ ...DEFAULT_PARAMS, pipeline_tag: "image-text-to-text" }), requestInit),
+		])
+
+		let textGenModels: HuggingFaceModel[] = []
+		let imgTextModels: HuggingFaceModel[] = []
+		let hasErrors = false
+
+		// Process text-generation models
+		if (textGenResponse.status === "fulfilled" && textGenResponse.value.ok) {
+			textGenModels = await textGenResponse.value.json()
+		} else {
+			console.error("Failed to fetch text-generation models:", textGenResponse)
+			hasErrors = true
+		}
+
+		// Process image-text-to-text models
+		if (imgTextResponse.status === "fulfilled" && imgTextResponse.value.ok) {
+			imgTextModels = await imgTextResponse.value.json()
+		} else {
+			console.error("Failed to fetch image-text-to-text models:", imgTextResponse)
+			hasErrors = true
+		}
+
+		// Combine and filter models
+		const allModels = [...textGenModels, ...imgTextModels]
+			.filter((model) => model.inferenceProviderMapping.length > 0)
+			.sort((a, b) => a.id.toLowerCase().localeCompare(b.id.toLowerCase()))
+
+		// Update cache
+		cache = {
+			data: allModels,
+			timestamp: now,
+			status: hasErrors ? "partial" : "success",
+		}
+
+		console.log(`Fetched ${allModels.length} Hugging Face models (status: ${cache.status})`)
+		return allModels
+	} catch (error) {
+		console.error("Error fetching Hugging Face models:", error)
+
+		// Return cached data if available
+		if (cache) {
+			console.log("Using stale cached data due to fetch error")
+			cache.status = "error"
+			return cache.data
+		}
+
+		// No cache available, return empty array
+		return []
+	}
+}
+
+export function getCachedModels(): HuggingFaceModel[] | null {
+	return cache?.data || null
+}
+
+export function clearCache(): void {
+	cache = null
+}
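
Note: a quick sketch of the caching contract this module implements (assumed usage, not part of the commit):

	const first = await fetchHuggingFaceModels()  // hits the API and fills the module-level cache
	const second = await fetchHuggingFaceModels() // served from cache for up to 1 hour
	console.log(first === second)                 // true while the cache is warm
	clearCache()
	await fetchHuggingFaceModels()                // forces a refetch

	// For reference, buildApiUrl({ ...DEFAULT_PARAMS, pipeline_tag: "text-generation" }) serializes to
	// https://huggingface.co/api/models?filter=conversational&inference_provider=all&limit=100&pipeline_tag=text-generation&expand[]=inferenceProviderMapping&...
	// (the "[]" brackets are percent-encoded in the final URL string)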

src/shared/ExtensionMessage.ts (23 additions, 0 deletions)

@@ -67,6 +67,7 @@ export interface ExtensionMessage {
 		| "ollamaModels"
 		| "lmStudioModels"
 		| "vsCodeLmModels"
+		| "huggingFaceModels"
 		| "vsCodeLmApiAvailable"
 		| "updatePrompt"
 		| "systemPrompt"
@@ -135,6 +136,28 @@ export interface ExtensionMessage {
 	ollamaModels?: string[]
 	lmStudioModels?: string[]
 	vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[]
+	huggingFaceModels?: Array<{
+		_id: string
+		id: string
+		inferenceProviderMapping: Array<{
+			provider: string
+			providerId: string
+			status: "live" | "staging" | "error"
+			task: "conversational"
+		}>
+		trendingScore: number
+		config: {
+			architectures: string[]
+			model_type: string
+			tokenizer_config?: {
+				chat_template?: string | Array<{ name: string; template: string }>
+				model_max_length?: number
+			}
+		}
+		tags: string[]
+		pipeline_tag: "text-generation" | "image-text-to-text"
+		library_name?: string
+	}>
 	mcpServers?: McpServer[]
 	commits?: GitCommit[]
 	listApiConfig?: ProviderSettingsEntry[]
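
Note: the inline payload type above mirrors HuggingFaceModel from the service. A hypothetical consumer (not in this commit) might flatten each model's mapping down to its usable providers, e.g.:

	function liveProviders(model: HuggingFaceModel): string[] {
		return model.inferenceProviderMapping
			.filter((mapping) => mapping.status === "live")
			.map((mapping) => mapping.provider)
	}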

src/shared/WebviewMessage.ts (1 addition, 0 deletions)

@@ -67,6 +67,7 @@ export interface WebviewMessage {
 		| "requestOllamaModels"
 		| "requestLmStudioModels"
 		| "requestVsCodeLmModels"
+		| "requestHuggingFaceModels"
 		| "openImage"
 		| "saveImage"
 		| "openFile"
