Skip to content

Commit a170a28

Browse files
committed
Harmonize inferenceProviderMapping additional parameter in modelInfo / listModel
1 parent 49d93f2 commit a170a28

File tree

6 files changed

+105
-5
lines changed

6 files changed

+105
-5
lines changed

packages/hub/src/lib/list-models.spec.ts

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,4 +115,23 @@ describe("listModels", () => {
115115

116116
expect(count).to.equal(10);
117117
});
118+
119+
it("should list deepseek-ai models with inference provider mapping", async () => {
120+
let count = 0;
121+
for await (const entry of listModels({
122+
search: { owner: "deepseek-ai" },
123+
additionalFields: ["inferenceProviderMapping"],
124+
limit: 1,
125+
})) {
126+
count++;
127+
expect(entry.inferenceProviderMapping).to.be.an("array").that.is.not.empty;
128+
for (const item of entry.inferenceProviderMapping ?? []) {
129+
expect(item).to.have.property("provider").that.is.a("string").and.is.not.empty;
130+
expect(item).to.have.property("hfModelId").that.is.a("string").and.is.not.empty;
131+
expect(item).to.have.property("providerId").that.is.a("string").and.is.not.empty;
132+
}
133+
}
134+
135+
expect(count).to.equal(1);
136+
});
118137
});

packages/hub/src/lib/list-models.ts

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import type { CredentialsParams, PipelineType } from "../types/public";
55
import { checkCredentials } from "../utils/checkCredentials";
66
import { parseLinkHeader } from "../utils/parseLinkHeader";
77
import { pick } from "../utils/pick";
8+
import { normalizeInferenceProviderMapping } from "../utils/normalizeInferenceProviderMapping";
89

910
export const MODEL_EXPAND_KEYS = [
1011
"pipeline_tag",
@@ -113,8 +114,20 @@ export async function* listModels<
113114
const items: ApiModelInfo[] = await res.json();
114115

115116
for (const item of items) {
117+
// Handle inferenceProviderMapping normalization
118+
const normalizedItem = { ...item };
119+
if (
120+
(params?.additionalFields as string[])?.includes("inferenceProviderMapping") &&
121+
item.inferenceProviderMapping
122+
) {
123+
normalizedItem.inferenceProviderMapping = normalizeInferenceProviderMapping(
124+
item.id,
125+
item.inferenceProviderMapping
126+
);
127+
}
128+
116129
yield {
117-
...(params?.additionalFields && pick(item, params.additionalFields)),
130+
...(params?.additionalFields && pick(normalizedItem, params.additionalFields)),
118131
id: item._id,
119132
name: item.id,
120133
private: item.private,

packages/hub/src/lib/model-info.spec.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,20 @@ describe("modelInfo", () => {
5656
sha: "f27b190eeac4c2302d24068eabf5e9d6044389ae",
5757
});
5858
});
59+
60+
it("should return model info deepseek-ai models with inference provider mapping", async () => {
61+
const info = await modelInfo({
62+
name: "deepseek-ai/DeepSeek-R1-0528",
63+
additionalFields: ["inferenceProviderMapping"],
64+
});
65+
66+
expect(info.inferenceProviderMapping).toBeDefined();
67+
expect(info.inferenceProviderMapping).toBeInstanceOf(Array);
68+
expect(info.inferenceProviderMapping?.length).toBeGreaterThan(0);
69+
info.inferenceProviderMapping?.forEach((item) => {
70+
expect(item).toHaveProperty("provider");
71+
expect(item).toHaveProperty("hf_model_id", "deepseek-ai/DeepSeek-R1-0528");
72+
expect(item).toHaveProperty("provider_id");
73+
});
74+
});
5975
});

packages/hub/src/lib/model-info.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import type { ApiModelInfo } from "../types/api/api-model";
44
import type { CredentialsParams } from "../types/public";
55
import { checkCredentials } from "../utils/checkCredentials";
66
import { pick } from "../utils/pick";
7+
import { normalizeInferenceProviderMapping } from "../utils/normalizeInferenceProviderMapping";
78
import { MODEL_EXPAND_KEYS, type MODEL_EXPANDABLE_KEYS, type ModelEntry } from "./list-models";
89

910
export async function modelInfo<
@@ -48,8 +49,14 @@ export async function modelInfo<
4849

4950
const data = await response.json();
5051

52+
// Handle inferenceProviderMapping normalization
53+
const normalizedData = { ...data };
54+
if ((params?.additionalFields as string[])?.includes("inferenceProviderMapping") && data.inferenceProviderMapping) {
55+
normalizedData.inferenceProviderMapping = normalizeInferenceProviderMapping(data.id, data.inferenceProviderMapping);
56+
}
57+
5158
return {
52-
...(params?.additionalFields && pick(data, params.additionalFields)),
59+
...(params?.additionalFields && pick(normalizedData, params.additionalFields)),
5360
id: data._id,
5461
name: data.id,
5562
private: data.private,

packages/hub/src/types/api/api-model.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@ export interface ApiModelInfo {
1818
downloadsAllTime: number;
1919
files: string[];
2020
gitalyUid: string;
21-
inferenceProviderMapping: Partial<
22-
Record<string, { providerId: string; status: "live" | "staging"; task: WidgetType }>
23-
>;
21+
inferenceProviderMapping?: ApiModelInferenceProviderMappingEntry[];
2422
lastAuthor: { email: string; user?: string };
2523
lastModified: string; // convert to date
2624
library_name?: ModelLibraryKey;
@@ -271,3 +269,14 @@ export interface ApiModelMetadata {
271269
extra_gated_description?: string;
272270
extra_gated_button_content?: string;
273271
}
272+
273+
export interface ApiModelInferenceProviderMappingEntry {
274+
provider: string; // Provider name
275+
hf_model_id: string; // ID of the model on the Hugging Face Hub
276+
provider_id: string; // ID of the model on the provider's side
277+
status: "live" | "staging";
278+
task: WidgetType;
279+
adapter?: string;
280+
adapter_weights_path?: string;
281+
type?: "single-file" | "tag-filter";
282+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import type { WidgetType } from "@huggingface/tasks";
2+
import type { ApiModelInferenceProviderMappingEntry } from "../types/api/api-model";
3+
4+
/**
5+
* Normalize inferenceProviderMapping to always return an array format.
6+
*
7+
* Little hack to simplify Inference Providers logic and make it backward and forward compatible.
8+
* Right now, API returns a dict on model-info and a list on list-models. Let's harmonize to list.
9+
*/
10+
export function normalizeInferenceProviderMapping(
11+
hf_model_id: string,
12+
inferenceProviderMapping?:
13+
| ApiModelInferenceProviderMappingEntry[]
14+
| Record<string, { providerId: string; status: "live" | "staging"; task: WidgetType }>
15+
): ApiModelInferenceProviderMappingEntry[] {
16+
if (!inferenceProviderMapping) {
17+
return [];
18+
}
19+
20+
// If it's already an array, return it as is
21+
if (Array.isArray(inferenceProviderMapping)) {
22+
return inferenceProviderMapping.map((entry) => ({
23+
...entry,
24+
hf_model_id,
25+
}));
26+
}
27+
28+
// Convert mapping to array format
29+
return Object.entries(inferenceProviderMapping).map(([provider, mapping]) => ({
30+
provider,
31+
hf_model_id,
32+
provider_id: mapping.providerId,
33+
status: mapping.status,
34+
task: mapping.task,
35+
}));
36+
}

0 commit comments

Comments
 (0)