Skip to content

Commit 98465e2

Browse files
authored
Merge branch 'main' into feat/wavespeedai
2 parents fd20f75 + 4e05d9e commit 98465e2

35 files changed

+409
-55
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ You can run our packages with vanilla JS, without any bundler, by using a CDN or
9797

9898
```html
9999
<script type="module">
100-
import { InferenceClient } from 'https://cdn.jsdelivr.net/npm/@huggingface/inference@4.0.3/+esm';
100+
import { InferenceClient } from 'https://cdn.jsdelivr.net/npm/@huggingface/inference@4.0.4/+esm';
101101
import { createRepo, commit, deleteRepo, listFiles } from "https://cdn.jsdelivr.net/npm/@huggingface/[email protected]/+esm";
102102
</script>
103103
```

packages/hub/src/lib/list-models.spec.ts

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,4 +115,23 @@ describe("listModels", () => {
115115

116116
expect(count).to.equal(10);
117117
});
118+
119+
it("should list deepseek-ai models with inference provider mapping", async () => {
120+
let count = 0;
121+
for await (const entry of listModels({
122+
search: { owner: "deepseek-ai" },
123+
additionalFields: ["inferenceProviderMapping"],
124+
limit: 1,
125+
})) {
126+
count++;
127+
expect(entry.inferenceProviderMapping).to.be.an("array").that.is.not.empty;
128+
for (const item of entry.inferenceProviderMapping ?? []) {
129+
expect(item).to.have.property("provider").that.is.a("string").and.is.not.empty;
130+
expect(item).to.have.property("hfModelId").that.is.a("string").and.is.not.empty;
131+
expect(item).to.have.property("providerId").that.is.a("string").and.is.not.empty;
132+
}
133+
}
134+
135+
expect(count).to.equal(1);
136+
});
118137
});

packages/hub/src/lib/list-models.ts

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import type { CredentialsParams, PipelineType } from "../types/public";
55
import { checkCredentials } from "../utils/checkCredentials";
66
import { parseLinkHeader } from "../utils/parseLinkHeader";
77
import { pick } from "../utils/pick";
8+
import { normalizeInferenceProviderMapping } from "../utils/normalizeInferenceProviderMapping";
89

910
export const MODEL_EXPAND_KEYS = [
1011
"pipeline_tag",
@@ -113,8 +114,20 @@ export async function* listModels<
113114
const items: ApiModelInfo[] = await res.json();
114115

115116
for (const item of items) {
117+
// Handle inferenceProviderMapping normalization
118+
const normalizedItem = { ...item };
119+
if (
120+
(params?.additionalFields as string[])?.includes("inferenceProviderMapping") &&
121+
item.inferenceProviderMapping
122+
) {
123+
normalizedItem.inferenceProviderMapping = normalizeInferenceProviderMapping(
124+
item.id,
125+
item.inferenceProviderMapping
126+
);
127+
}
128+
116129
yield {
117-
...(params?.additionalFields && pick(item, params.additionalFields)),
130+
...(params?.additionalFields && pick(normalizedItem, params.additionalFields)),
118131
id: item._id,
119132
name: item.id,
120133
private: item.private,

packages/hub/src/lib/model-info.spec.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,20 @@ describe("modelInfo", () => {
5656
sha: "f27b190eeac4c2302d24068eabf5e9d6044389ae",
5757
});
5858
});
59+
60+
it("should return model info deepseek-ai models with inference provider mapping", async () => {
61+
const info = await modelInfo({
62+
name: "deepseek-ai/DeepSeek-R1-0528",
63+
additionalFields: ["inferenceProviderMapping"],
64+
});
65+
66+
expect(info.inferenceProviderMapping).toBeDefined();
67+
expect(info.inferenceProviderMapping).toBeInstanceOf(Array);
68+
expect(info.inferenceProviderMapping?.length).toBeGreaterThan(0);
69+
info.inferenceProviderMapping?.forEach((item) => {
70+
expect(item).toHaveProperty("provider");
71+
expect(item).toHaveProperty("hfModelId", "deepseek-ai/DeepSeek-R1-0528");
72+
expect(item).toHaveProperty("providerId");
73+
});
74+
});
5975
});

packages/hub/src/lib/model-info.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import type { ApiModelInfo } from "../types/api/api-model";
44
import type { CredentialsParams } from "../types/public";
55
import { checkCredentials } from "../utils/checkCredentials";
66
import { pick } from "../utils/pick";
7+
import { normalizeInferenceProviderMapping } from "../utils/normalizeInferenceProviderMapping";
78
import { MODEL_EXPAND_KEYS, type MODEL_EXPANDABLE_KEYS, type ModelEntry } from "./list-models";
89

910
export async function modelInfo<
@@ -48,8 +49,14 @@ export async function modelInfo<
4849

4950
const data = await response.json();
5051

52+
// Handle inferenceProviderMapping normalization
53+
const normalizedData = { ...data };
54+
if ((params?.additionalFields as string[])?.includes("inferenceProviderMapping") && data.inferenceProviderMapping) {
55+
normalizedData.inferenceProviderMapping = normalizeInferenceProviderMapping(data.id, data.inferenceProviderMapping);
56+
}
57+
5158
return {
52-
...(params?.additionalFields && pick(data, params.additionalFields)),
59+
...(params?.additionalFields && pick(normalizedData, params.additionalFields)),
5360
id: data._id,
5461
name: data.id,
5562
private: data.private,

packages/hub/src/types/api/api-model.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@ export interface ApiModelInfo {
1818
downloadsAllTime: number;
1919
files: string[];
2020
gitalyUid: string;
21-
inferenceProviderMapping: Partial<
22-
Record<string, { providerId: string; status: "live" | "staging"; task: WidgetType }>
23-
>;
21+
inferenceProviderMapping?: ApiModelInferenceProviderMappingEntry[];
2422
lastAuthor: { email: string; user?: string };
2523
lastModified: string; // convert to date
2624
library_name?: ModelLibraryKey;
@@ -271,3 +269,14 @@ export interface ApiModelMetadata {
271269
extra_gated_description?: string;
272270
extra_gated_button_content?: string;
273271
}
272+
273+
export interface ApiModelInferenceProviderMappingEntry {
274+
provider: string; // Provider name
275+
hfModelId: string; // ID of the model on the Hugging Face Hub
276+
providerId: string; // ID of the model on the provider's side
277+
status: "live" | "staging";
278+
task: WidgetType;
279+
adapter?: string;
280+
adapterWeightsPath?: string;
281+
type?: "single-file" | "tag-filter";
282+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import type { WidgetType } from "@huggingface/tasks";
2+
import type { ApiModelInferenceProviderMappingEntry } from "../types/api/api-model";
3+
4+
/**
5+
* Normalize inferenceProviderMapping to always return an array format.
6+
*
7+
* Little hack to simplify Inference Providers logic and make it backward and forward compatible.
8+
* Right now, API returns a dict on model-info and a list on list-models. Let's harmonize to list.
9+
*/
10+
export function normalizeInferenceProviderMapping(
11+
hfModelId: string,
12+
inferenceProviderMapping?:
13+
| ApiModelInferenceProviderMappingEntry[]
14+
| Record<string, { providerId: string; status: "live" | "staging"; task: WidgetType }>
15+
): ApiModelInferenceProviderMappingEntry[] {
16+
if (!inferenceProviderMapping) {
17+
return [];
18+
}
19+
20+
// If it's already an array, return it as is
21+
if (Array.isArray(inferenceProviderMapping)) {
22+
return inferenceProviderMapping.map((entry) => ({
23+
...entry,
24+
hfModelId,
25+
}));
26+
}
27+
28+
// Convert mapping to array format
29+
return Object.entries(inferenceProviderMapping).map(([provider, mapping]) => ({
30+
provider,
31+
hfModelId,
32+
providerId: mapping.providerId,
33+
status: mapping.status,
34+
task: mapping.task,
35+
}));
36+
}

packages/inference/README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -653,9 +653,10 @@ You can use any Chat Completion API-compatible provider with the `chatCompletion
653653
```typescript
654654
// Chat Completion Example
655655
const MISTRAL_KEY = process.env.MISTRAL_KEY;
656-
const hf = new InferenceClient(MISTRAL_KEY);
657-
const ep = hf.endpoint("https://api.mistral.ai");
658-
const stream = ep.chatCompletionStream({
656+
const hf = new InferenceClient(MISTRAL_KEY, {
657+
endpointUrl: "https://api.mistral.ai",
658+
});
659+
const stream = hf.chatCompletionStream({
659660
model: "mistral-tiny",
660661
messages: [{ role: "user", content: "Complete the equation one + one = , just the answer" }],
661662
});

packages/inference/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@huggingface/inference",
3-
"version": "4.0.3",
3+
"version": "4.0.4",
44
"packageManager": "[email protected]",
55
"license": "MIT",
66
"author": "Hugging Face and Tim Mikeladze <tim.mikeladze@gmail.com>",

packages/inference/src/lib/getInferenceProviderMapping.ts

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,48 @@ import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../t
66
import { typedInclude } from "../utils/typedInclude.js";
77
import { InferenceClientHubApiError, InferenceClientInputError } from "../errors.js";
88

9-
export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
9+
export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMappingEntry[]>();
1010

11-
export type InferenceProviderMapping = Partial<
12-
Record<InferenceProvider, Omit<InferenceProviderModelMapping, "hfModelId">>
13-
>;
14-
15-
export interface InferenceProviderModelMapping {
11+
export interface InferenceProviderMappingEntry {
1612
adapter?: string;
1713
adapterWeightsPath?: string;
1814
hfModelId: ModelId;
15+
provider: string;
1916
providerId: string;
2017
status: "live" | "staging";
2118
task: WidgetType;
19+
type?: "single-model" | "tag-filter";
20+
}
21+
22+
/**
23+
* Normalize inferenceProviderMapping to always return an array format.
24+
* This provides backward and forward compatibility for the API changes.
25+
*
26+
* Vendored from @huggingface/hub to avoid extra dependency.
27+
*/
28+
function normalizeInferenceProviderMapping(
29+
modelId: ModelId,
30+
inferenceProviderMapping?:
31+
| InferenceProviderMappingEntry[]
32+
| Record<string, { providerId: string; status: "live" | "staging"; task: WidgetType }>
33+
): InferenceProviderMappingEntry[] {
34+
if (!inferenceProviderMapping) {
35+
return [];
36+
}
37+
38+
// If it's already an array, return it as is
39+
if (Array.isArray(inferenceProviderMapping)) {
40+
return inferenceProviderMapping;
41+
}
42+
43+
// Convert mapping to array format
44+
return Object.entries(inferenceProviderMapping).map(([provider, mapping]) => ({
45+
provider,
46+
hfModelId: modelId,
47+
providerId: mapping.providerId,
48+
status: mapping.status,
49+
task: mapping.task,
50+
}));
2251
}
2352

2453
export async function fetchInferenceProviderMappingForModel(
@@ -27,8 +56,8 @@ export async function fetchInferenceProviderMappingForModel(
2756
options?: {
2857
fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
2958
}
30-
): Promise<InferenceProviderMapping> {
31-
let inferenceProviderMapping: InferenceProviderMapping | null;
59+
): Promise<InferenceProviderMappingEntry[]> {
60+
let inferenceProviderMapping: InferenceProviderMappingEntry[] | null;
3261
if (inferenceProviderMappingCache.has(modelId)) {
3362
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
3463
inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
@@ -55,7 +84,11 @@ export async function fetchInferenceProviderMappingForModel(
5584
);
5685
}
5786
}
58-
let payload: { inferenceProviderMapping?: InferenceProviderMapping } | null = null;
87+
let payload: {
88+
inferenceProviderMapping?:
89+
| InferenceProviderMappingEntry[]
90+
| Record<string, { providerId: string; status: "live" | "staging"; task: WidgetType }>;
91+
} | null = null;
5992
try {
6093
payload = await resp.json();
6194
} catch {
@@ -72,7 +105,8 @@ export async function fetchInferenceProviderMappingForModel(
72105
{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
73106
);
74107
}
75-
inferenceProviderMapping = payload.inferenceProviderMapping;
108+
inferenceProviderMapping = normalizeInferenceProviderMapping(modelId, payload.inferenceProviderMapping);
109+
inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
76110
}
77111
return inferenceProviderMapping;
78112
}
@@ -87,16 +121,12 @@ export async function getInferenceProviderMapping(
87121
options: {
88122
fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
89123
}
90-
): Promise<InferenceProviderModelMapping | null> {
124+
): Promise<InferenceProviderMappingEntry | null> {
91125
if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
92126
return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
93127
}
94-
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
95-
params.modelId,
96-
params.accessToken,
97-
options
98-
);
99-
const providerMapping = inferenceProviderMapping[params.provider];
128+
const mappings = await fetchInferenceProviderMappingForModel(params.modelId, params.accessToken, options);
129+
const providerMapping = mappings.find((mapping) => mapping.provider === params.provider);
100130
if (providerMapping) {
101131
const equivalentTasks =
102132
params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task)
@@ -112,7 +142,7 @@ export async function getInferenceProviderMapping(
112142
`Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
113143
);
114144
}
115-
return { ...providerMapping, hfModelId: params.modelId };
145+
return providerMapping;
116146
}
117147
return null;
118148
}
@@ -139,8 +169,8 @@ export async function resolveProvider(
139169
if (!modelId) {
140170
throw new InferenceClientInputError("Specifying a model is required when provider is 'auto'");
141171
}
142-
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
143-
provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider | undefined;
172+
const mappings = await fetchInferenceProviderMappingForModel(modelId);
173+
provider = mappings[0]?.provider as InferenceProvider | undefined;
144174
}
145175
if (!provider) {
146176
throw new InferenceClientInputError(`No Inference Provider available for model ${modelId}.`);

0 commit comments

Comments
 (0)