Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions packages/hub/src/lib/list-models.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,23 @@ describe("listModels", () => {

expect(count).to.equal(10);
});

it("should list deepseek-ai models with inference provider mapping", async () => {
	// Request a single deepseek-ai model with its provider mapping expanded.
	let seen = 0;
	const iterator = listModels({
		search: { owner: "deepseek-ai" },
		additionalFields: ["inferenceProviderMapping"],
		limit: 1,
	});
	for await (const model of iterator) {
		seen++;
		// The mapping must be normalized to a non-empty array of well-formed entries.
		expect(model.inferenceProviderMapping).to.be.an("array").that.is.not.empty;
		for (const mapping of model.inferenceProviderMapping ?? []) {
			expect(mapping).to.have.property("provider").that.is.a("string").and.is.not.empty;
			expect(mapping).to.have.property("hfModelId").that.is.a("string").and.is.not.empty;
			expect(mapping).to.have.property("providerId").that.is.a("string").and.is.not.empty;
		}
	}

	// `limit: 1` means exactly one model should have been yielded.
	expect(seen).to.equal(1);
});
});
15 changes: 14 additions & 1 deletion packages/hub/src/lib/list-models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import type { CredentialsParams, PipelineType } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { pick } from "../utils/pick";
import { normalizeInferenceProviderMapping } from "../utils/normalizeInferenceProviderMapping";

export const MODEL_EXPAND_KEYS = [
"pipeline_tag",
Expand Down Expand Up @@ -113,8 +114,20 @@ export async function* listModels<
const items: ApiModelInfo[] = await res.json();

for (const item of items) {
// Handle inferenceProviderMapping normalization
const normalizedItem = { ...item };
if (
(params?.additionalFields as string[])?.includes("inferenceProviderMapping") &&
item.inferenceProviderMapping
) {
normalizedItem.inferenceProviderMapping = normalizeInferenceProviderMapping(
item.id,
item.inferenceProviderMapping
);
}

yield {
...(params?.additionalFields && pick(item, params.additionalFields)),
...(params?.additionalFields && pick(normalizedItem, params.additionalFields)),
id: item._id,
name: item.id,
private: item.private,
Expand Down
16 changes: 16 additions & 0 deletions packages/hub/src/lib/model-info.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,20 @@ describe("modelInfo", () => {
sha: "f27b190eeac4c2302d24068eabf5e9d6044389ae",
});
});

it("should return model info for deepseek-ai models with inference provider mapping", async () => {
	// Fetch a single model with the expanded `inferenceProviderMapping` field.
	const info = await modelInfo({
		name: "deepseek-ai/DeepSeek-R1-0528",
		additionalFields: ["inferenceProviderMapping"],
	});

	// The field must be normalized to a non-empty array (the raw API may return a dict).
	expect(info.inferenceProviderMapping).toBeDefined();
	expect(info.inferenceProviderMapping).toBeInstanceOf(Array);
	expect(info.inferenceProviderMapping?.length).toBeGreaterThan(0);
	info.inferenceProviderMapping?.forEach((item) => {
		expect(item).toHaveProperty("provider");
		// Every entry is stamped with the Hub model id it was fetched for.
		expect(item).toHaveProperty("hfModelId", "deepseek-ai/DeepSeek-R1-0528");
		expect(item).toHaveProperty("providerId");
	});
});
});
9 changes: 8 additions & 1 deletion packages/hub/src/lib/model-info.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import type { ApiModelInfo } from "../types/api/api-model";
import type { CredentialsParams } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { pick } from "../utils/pick";
import { normalizeInferenceProviderMapping } from "../utils/normalizeInferenceProviderMapping";
import { MODEL_EXPAND_KEYS, type MODEL_EXPANDABLE_KEYS, type ModelEntry } from "./list-models";

export async function modelInfo<
Expand Down Expand Up @@ -48,8 +49,14 @@ export async function modelInfo<

const data = await response.json();

// Handle inferenceProviderMapping normalization
const normalizedData = { ...data };
if ((params?.additionalFields as string[])?.includes("inferenceProviderMapping") && data.inferenceProviderMapping) {
normalizedData.inferenceProviderMapping = normalizeInferenceProviderMapping(data.id, data.inferenceProviderMapping);
}

return {
...(params?.additionalFields && pick(data, params.additionalFields)),
...(params?.additionalFields && pick(normalizedData, params.additionalFields)),
id: data._id,
name: data.id,
private: data.private,
Expand Down
15 changes: 12 additions & 3 deletions packages/hub/src/types/api/api-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ export interface ApiModelInfo {
downloadsAllTime: number;
files: string[];
gitalyUid: string;
inferenceProviderMapping: Partial<
Record<string, { providerId: string; status: "live" | "staging"; task: WidgetType }>
>;
inferenceProviderMapping?: ApiModelInferenceProviderMappingEntry[];
lastAuthor: { email: string; user?: string };
lastModified: string; // convert to date
library_name?: ModelLibraryKey;
Expand Down Expand Up @@ -271,3 +269,14 @@ export interface ApiModelMetadata {
extra_gated_description?: string;
extra_gated_button_content?: string;
}

/**
 * One entry of a model's inference-provider mapping, in the normalized list form.
 *
 * The Hub API currently returns a dict on model-info and a list on list-models;
 * `normalizeInferenceProviderMapping` harmonizes both shapes to this entry type.
 */
export interface ApiModelInferenceProviderMappingEntry {
	provider: string; // Provider name
	hfModelId: string; // ID of the model on the Hugging Face Hub
	providerId: string; // ID of the model on the provider's side
	status: "live" | "staging";
	task: WidgetType;
	adapter?: string; // presumably set for adapter (e.g. LoRA) models — confirm against API payloads
	adapterWeightsPath?: string;
	// "single-model" (not "single-file") to stay consistent with the
	// InferenceProviderMappingEntry union declared in @huggingface/inference.
	type?: "single-model" | "tag-filter";
}
36 changes: 36 additions & 0 deletions packages/hub/src/utils/normalizeInferenceProviderMapping.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import type { WidgetType } from "@huggingface/tasks";
import type { ApiModelInferenceProviderMappingEntry } from "../types/api/api-model";

/**
 * Harmonize the `inferenceProviderMapping` field into a list of entries.
 *
 * Compatibility shim to simplify Inference Providers logic, backward and forward
 * compatible: the API currently returns a dict on model-info and a list on
 * list-models, so both shapes are normalized to the list form here.
 */
export function normalizeInferenceProviderMapping(
	hfModelId: string,
	inferenceProviderMapping?:
		| ApiModelInferenceProviderMappingEntry[]
		| Record<string, { providerId: string; status: "live" | "staging"; task: WidgetType }>
): ApiModelInferenceProviderMappingEntry[] {
	if (!inferenceProviderMapping) {
		return [];
	}

	if (Array.isArray(inferenceProviderMapping)) {
		// Already list-shaped: just stamp each entry with the Hub model id.
		return inferenceProviderMapping.map((entry) => ({ ...entry, hfModelId }));
	}

	// Dict-shaped: flatten `provider -> data` pairs into one entry per provider.
	const entries: ApiModelInferenceProviderMappingEntry[] = [];
	for (const [providerName, providerData] of Object.entries(inferenceProviderMapping)) {
		entries.push({
			provider: providerName,
			hfModelId,
			providerId: providerData.providerId,
			status: providerData.status,
			task: providerData.task,
		});
	}
	return entries;
}
70 changes: 50 additions & 20 deletions packages/inference/src/lib/getInferenceProviderMapping.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,48 @@ import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../t
import { typedInclude } from "../utils/typedInclude.js";
import { InferenceClientHubApiError, InferenceClientInputError } from "../errors.js";

export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMappingEntry[]>();

export type InferenceProviderMapping = Partial<
Record<InferenceProvider, Omit<InferenceProviderModelMapping, "hfModelId">>
>;

export interface InferenceProviderModelMapping {
/**
 * One provider's mapping for a given Hub model — the list form of the
 * `inferenceProviderMapping` API field.
 */
export interface InferenceProviderMappingEntry {
	// Adapter fields — presumably present for adapter (e.g. LoRA) models; confirm against API payloads.
	adapter?: string;
	adapterWeightsPath?: string;
	/** Hub model id, e.g. "meta-llama/Llama-3.3-70B-Instruct" */
	hfModelId: ModelId;
	/** Provider name, e.g. "hf-inference" */
	provider: string;
	/** Model id on the provider's side */
	providerId: string;
	status: "live" | "staging";
	task: WidgetType;
	type?: "single-model" | "tag-filter";
}

/**
 * Normalize inferenceProviderMapping to always return an array format.
 * This provides backward and forward compatibility for the API changes.
 *
 * Vendored from @huggingface/hub to avoid extra dependency.
 */
function normalizeInferenceProviderMapping(
	modelId: ModelId,
	inferenceProviderMapping?:
		| InferenceProviderMappingEntry[]
		| Record<string, { providerId: string; status: "live" | "staging"; task: WidgetType }>
): InferenceProviderMappingEntry[] {
	if (!inferenceProviderMapping) {
		return [];
	}

	// List form (new API shape): nothing to normalize.
	if (Array.isArray(inferenceProviderMapping)) {
		return inferenceProviderMapping;
	}

	// Dict form (legacy shape): flatten `provider -> data` pairs into entries.
	const normalized: InferenceProviderMappingEntry[] = [];
	for (const [providerName, providerData] of Object.entries(inferenceProviderMapping)) {
		normalized.push({
			provider: providerName,
			hfModelId: modelId,
			providerId: providerData.providerId,
			status: providerData.status,
			task: providerData.task,
		});
	}
	return normalized;
}

export async function fetchInferenceProviderMappingForModel(
Expand All @@ -27,8 +56,8 @@ export async function fetchInferenceProviderMappingForModel(
options?: {
fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
}
): Promise<InferenceProviderMapping> {
let inferenceProviderMapping: InferenceProviderMapping | null;
): Promise<InferenceProviderMappingEntry[]> {
let inferenceProviderMapping: InferenceProviderMappingEntry[] | null;
if (inferenceProviderMappingCache.has(modelId)) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
Expand All @@ -55,7 +84,11 @@ export async function fetchInferenceProviderMappingForModel(
);
}
}
let payload: { inferenceProviderMapping?: InferenceProviderMapping } | null = null;
let payload: {
inferenceProviderMapping?:
| InferenceProviderMappingEntry[]
| Record<string, { providerId: string; status: "live" | "staging"; task: WidgetType }>;
} | null = null;
try {
payload = await resp.json();
} catch {
Expand All @@ -72,7 +105,8 @@ export async function fetchInferenceProviderMappingForModel(
{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
);
}
inferenceProviderMapping = payload.inferenceProviderMapping;
inferenceProviderMapping = normalizeInferenceProviderMapping(modelId, payload.inferenceProviderMapping);
inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
}
return inferenceProviderMapping;
}
Expand All @@ -87,16 +121,12 @@ export async function getInferenceProviderMapping(
options: {
fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
}
): Promise<InferenceProviderModelMapping | null> {
): Promise<InferenceProviderMappingEntry | null> {
if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
}
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
params.modelId,
params.accessToken,
options
);
const providerMapping = inferenceProviderMapping[params.provider];
const mappings = await fetchInferenceProviderMappingForModel(params.modelId, params.accessToken, options);
const providerMapping = mappings.find((mapping) => mapping.provider === params.provider);
if (providerMapping) {
const equivalentTasks =
params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task)
Expand All @@ -112,7 +142,7 @@ export async function getInferenceProviderMapping(
`Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
);
}
return { ...providerMapping, hfModelId: params.modelId };
return providerMapping;
}
return null;
}
Expand All @@ -139,8 +169,8 @@ export async function resolveProvider(
if (!modelId) {
throw new InferenceClientInputError("Specifying a model is required when provider is 'auto'");
}
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider | undefined;
const mappings = await fetchInferenceProviderMappingForModel(modelId);
provider = mappings[0]?.provider as InferenceProvider | undefined;
}
if (!provider) {
throw new InferenceClientInputError(`No Inference Provider available for model ${modelId}.`);
Expand Down
7 changes: 4 additions & 3 deletions packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { HF_HEADER_X_BILL_TO, HF_HUB_URL } from "../config.js";
import { PACKAGE_NAME, PACKAGE_VERSION } from "../package.js";
import type { InferenceTask, Options, RequestArgs } from "../types.js";
import type { InferenceProviderModelMapping } from "./getInferenceProviderMapping.js";
import type { InferenceProviderMappingEntry } from "./getInferenceProviderMapping.js";
import { getInferenceProviderMapping } from "./getInferenceProviderMapping.js";
import type { getProviderHelper } from "./getProviderHelper.js";
import { isUrl } from "./isUrl.js";
Expand Down Expand Up @@ -64,14 +64,15 @@ export async function makeRequestOptions(

const inferenceProviderMapping = providerHelper.clientSideRoutingOnly
? ({
provider: provider,
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
providerId: removeProviderPrefix(maybeModel!, provider),
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
hfModelId: maybeModel!,
status: "live",
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
task: task!,
} satisfies InferenceProviderModelMapping)
} satisfies InferenceProviderMappingEntry)
: await getInferenceProviderMapping(
{
modelId: hfModel,
Expand Down Expand Up @@ -109,7 +110,7 @@ export function makeRequestOptionsFromResolvedModel(
data?: Blob | ArrayBuffer;
stream?: boolean;
},
mapping: InferenceProviderModelMapping | undefined,
mapping: InferenceProviderMappingEntry | undefined,
options?: Options & {
task?: InferenceTask;
}
Expand Down
4 changes: 2 additions & 2 deletions packages/inference/src/providers/consts.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { InferenceProviderModelMapping } from "../lib/getInferenceProviderMapping.js";
import type { InferenceProviderMappingEntry } from "../lib/getInferenceProviderMapping.js";
import type { InferenceProvider } from "../types.js";
import { type ModelId } from "../types.js";

Expand All @@ -11,7 +11,7 @@ import { type ModelId } from "../types.js";
*/
export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
InferenceProvider,
Record<ModelId, InferenceProviderModelMapping>
Record<ModelId, InferenceProviderMappingEntry>
> = {
/**
* "HF model ID" => "Model ID on Inference Provider's side"
Expand Down
8 changes: 4 additions & 4 deletions packages/inference/src/snippets/getInferenceSnippets.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
} from "@huggingface/tasks";
import type { PipelineType, WidgetType } from "@huggingface/tasks";
import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingface/tasks";
import type { InferenceProviderModelMapping } from "../lib/getInferenceProviderMapping.js";
import type { InferenceProviderMappingEntry } from "../lib/getInferenceProviderMapping.js";
import { getProviderHelper } from "../lib/getProviderHelper.js";
import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions.js";
import type { InferenceProviderOrPolicy, InferenceTask, RequestArgs } from "../types.js";
Expand Down Expand Up @@ -136,7 +136,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
return (
model: ModelDataMinimal,
provider: InferenceProviderOrPolicy,
inferenceProviderMapping?: InferenceProviderModelMapping,
inferenceProviderMapping?: InferenceProviderMappingEntry,
opts?: InferenceSnippetOptions
): InferenceSnippet[] => {
const providerModelId = inferenceProviderMapping?.providerId ?? model.id;
Expand Down Expand Up @@ -320,7 +320,7 @@ const snippets: Partial<
(
model: ModelDataMinimal,
provider: InferenceProviderOrPolicy,
inferenceProviderMapping?: InferenceProviderModelMapping,
inferenceProviderMapping?: InferenceProviderMappingEntry,
opts?: InferenceSnippetOptions
) => InferenceSnippet[]
>
Expand Down Expand Up @@ -359,7 +359,7 @@ const snippets: Partial<
export function getInferenceSnippets(
model: ModelDataMinimal,
provider: InferenceProviderOrPolicy,
inferenceProviderMapping?: InferenceProviderModelMapping,
inferenceProviderMapping?: InferenceProviderMappingEntry,
opts?: Record<string, unknown>
): InferenceSnippet[] {
return model.pipeline_tag && model.pipeline_tag in snippets
Expand Down
4 changes: 2 additions & 2 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import type { ChatCompletionInput, PipelineType } from "@huggingface/tasks";
import type { InferenceProviderModelMapping } from "./lib/getInferenceProviderMapping.js";
import type { InferenceProviderMappingEntry } from "./lib/getInferenceProviderMapping.js";

/**
* HF model id, like "meta-llama/Llama-3.3-70B-Instruct"
Expand Down Expand Up @@ -126,6 +126,6 @@ export interface UrlParams {
export interface BodyParams<T extends Record<string, unknown> = Record<string, unknown>> {
args: T;
model: string;
mapping?: InferenceProviderModelMapping | undefined;
mapping?: InferenceProviderMappingEntry | undefined;
task?: InferenceTask;
}
Loading
Loading