Skip to content

Commit 38d13dd

Browse files
Wauplin, SBrandeis, julien-c
authored
[inference] Dynamic inference provider mapping (#1173)
_"it should work" 😄_ Still a draft while huggingface-internal/moon-landing#12398 (internal) is being merged/deployed. Goal is to use the dynamic mapping, and fall back to hardcoded model ids if necessary (for backward compatibility). I haven't tested anything for now and I left some todos to address: - [x] what to do with `status: "live" | "staging"` ? => raise a warning - [ ] we need to cache the `modelInfo` call (only do it once at runtime) - [ ] how to handle if model supports both `text-generation` and `conversational` - [ ] how to deal with `"hf-inference"` if no taskHints? (for now, I kept as before) - [ ] tests are flaky (requires server-side update) **EDIT:** made an update to preserve previous behavior with hardcoded mapping. Dynamic mapping from the hub takes precedence. --------- Co-authored-by: SBrandeis <[email protected]> Co-authored-by: Julien Chaumond <[email protected]>
1 parent 671d03a commit 38d13dd

36 files changed

+260
-289
lines changed

packages/hub/src/lib/list-models.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ export const MODEL_EXPANDABLE_KEYS = [
2525
"downloadsAllTime",
2626
"gated",
2727
"gitalyUid",
28+
"inferenceProviderMapping",
2829
"lastModified",
2930
"library_name",
3031
"likes",

packages/hub/src/types/api/api-model.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import type { ModelLibraryKey, TransformersInfo } from "@huggingface/tasks";
1+
import type { ModelLibraryKey, TransformersInfo, WidgetType } from "@huggingface/tasks";
22
import type { License, PipelineType } from "../public";
33

44
export interface ApiModelInfo {
@@ -18,6 +18,9 @@ export interface ApiModelInfo {
1818
downloadsAllTime: number;
1919
files: string[];
2020
gitalyUid: string;
21+
inferenceProviderMapping: Partial<
22+
Record<string, { providerId: string; status: "live" | "staging"; task: WidgetType }>
23+
>;
2124
lastAuthor: { email: string; user?: string };
2225
lastModified: string; // convert to date
2326
library_name?: ModelLibraryKey;

packages/inference/src/index.ts

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,4 @@
1-
export type { ProviderMapping } from "./providers/types";
21
export { HfInference, HfInferenceEndpoint } from "./HfInference";
32
export { InferenceOutputError } from "./lib/InferenceOutputError";
4-
export { FAL_AI_SUPPORTED_MODEL_IDS } from "./providers/fal-ai";
5-
export { REPLICATE_SUPPORTED_MODEL_IDS } from "./providers/replicate";
6-
export { SAMBANOVA_SUPPORTED_MODEL_IDS } from "./providers/sambanova";
7-
export { TOGETHER_SUPPORTED_MODEL_IDS } from "./providers/together";
83
export * from "./types";
94
export * from "./tasks";
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import type { WidgetType } from "@huggingface/tasks";
2+
import type { InferenceProvider, InferenceTask, ModelId, Options, RequestArgs } from "../types";
3+
import { HF_HUB_URL } from "../config";
4+
import { HARDCODED_MODEL_ID_MAPPING } from "../providers/consts";
5+
6+
type InferenceProviderMapping = Partial<
7+
Record<InferenceProvider, { providerId: string; status: "live" | "staging"; task: WidgetType }>
8+
>;
9+
const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
10+
11+
export async function getProviderModelId(
12+
params: {
13+
model: string;
14+
provider: InferenceProvider;
15+
},
16+
args: RequestArgs,
17+
options: {
18+
taskHint?: InferenceTask;
19+
chatCompletion?: boolean;
20+
fetch?: Options["fetch"];
21+
} = {}
22+
): Promise<string> {
23+
if (params.provider === "hf-inference") {
24+
return params.model;
25+
}
26+
if (!options.taskHint) {
27+
throw new Error("taskHint must be specified when using a third-party provider");
28+
}
29+
const task: WidgetType =
30+
options.taskHint === "text-generation" && options.chatCompletion ? "conversational" : options.taskHint;
31+
32+
// A dict called HARDCODED_MODEL_ID_MAPPING takes precedence in all cases (useful for dev purposes)
33+
if (HARDCODED_MODEL_ID_MAPPING[params.model]) {
34+
return HARDCODED_MODEL_ID_MAPPING[params.model];
35+
}
36+
37+
let inferenceProviderMapping: InferenceProviderMapping | null;
38+
if (inferenceProviderMappingCache.has(params.model)) {
39+
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
40+
inferenceProviderMapping = inferenceProviderMappingCache.get(params.model)!;
41+
} else {
42+
inferenceProviderMapping = await (options?.fetch ?? fetch)(
43+
`${HF_HUB_URL}/api/models/${params.model}?expand[]=inferenceProviderMapping`,
44+
{
45+
headers: args.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${args.accessToken}` } : {},
46+
}
47+
)
48+
.then((resp) => resp.json())
49+
.then((json) => json.inferenceProviderMapping)
50+
.catch(() => null);
51+
}
52+
53+
if (!inferenceProviderMapping) {
54+
throw new Error(`We have not been able to find inference provider information for model ${params.model}.`);
55+
}
56+
57+
const providerMapping = inferenceProviderMapping[params.provider];
58+
if (providerMapping) {
59+
if (providerMapping.task !== task) {
60+
throw new Error(
61+
`Model ${params.model} is not supported for task ${task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
62+
);
63+
}
64+
if (providerMapping.status === "staging") {
65+
console.warn(
66+
`Model ${params.model} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
67+
);
68+
}
69+
// TODO: how is it handled server-side if model has multiple tasks (e.g. `text-generation` + `conversational`)?
70+
return providerMapping.providerId;
71+
}
72+
73+
throw new Error(`Model ${params.model} is not supported provider ${params.provider}.`);
74+
}

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 14 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
import type { WidgetType } from "@huggingface/tasks";
21
import { HF_HUB_URL } from "../config";
3-
import { FAL_AI_API_BASE_URL, FAL_AI_SUPPORTED_MODEL_IDS } from "../providers/fal-ai";
4-
import { REPLICATE_API_BASE_URL, REPLICATE_SUPPORTED_MODEL_IDS } from "../providers/replicate";
5-
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_SUPPORTED_MODEL_IDS } from "../providers/sambanova";
6-
import { TOGETHER_API_BASE_URL, TOGETHER_SUPPORTED_MODEL_IDS } from "../providers/together";
2+
import { FAL_AI_API_BASE_URL } from "../providers/fal-ai";
3+
import { REPLICATE_API_BASE_URL } from "../providers/replicate";
4+
import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
5+
import { TOGETHER_API_BASE_URL } from "../providers/together";
76
import type { InferenceProvider } from "../types";
87
import type { InferenceTask, Options, RequestArgs } from "../types";
98
import { isUrl } from "./isUrl";
109
import { version as packageVersion, name as packageName } from "../../package.json";
10+
import { getProviderModelId } from "./getProviderModelId";
1111

1212
const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_HUB_URL}/api/inference-proxy/{{PROVIDER}}`;
1313

@@ -49,18 +49,16 @@ export async function makeRequestOptions(
4949
if (maybeModel && isUrl(maybeModel)) {
5050
throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
5151
}
52-
53-
let model: string;
54-
if (!maybeModel) {
55-
if (taskHint) {
56-
model = mapModel({ model: await loadDefaultModel(taskHint), provider, taskHint, chatCompletion });
57-
} else {
58-
throw new Error("No model provided, and no default model found for this task");
59-
/// TODO : change error message ^
60-
}
61-
} else {
62-
model = mapModel({ model: maybeModel, provider, taskHint, chatCompletion });
52+
if (!maybeModel && !taskHint) {
53+
throw new Error("No model provided, and no task has been specified.");
6354
}
55+
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
56+
const hfModel = maybeModel ?? (await loadDefaultModel(taskHint!));
57+
const model = await getProviderModelId({ model: hfModel, provider }, args, {
58+
taskHint,
59+
chatCompletion,
60+
fetch: options?.fetch,
61+
});
6462

6563
/// If accessToken is passed, it should take precedence over includeCredentials
6664
const authMethod = accessToken
@@ -153,39 +151,6 @@ export async function makeRequestOptions(
153151
return { url, info };
154152
}
155153

156-
function mapModel(params: {
157-
model: string;
158-
provider: InferenceProvider;
159-
taskHint: InferenceTask | undefined;
160-
chatCompletion: boolean | undefined;
161-
}): string {
162-
if (params.provider === "hf-inference") {
163-
return params.model;
164-
}
165-
if (!params.taskHint) {
166-
throw new Error("taskHint must be specified when using a third-party provider");
167-
}
168-
const task: WidgetType =
169-
params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
170-
const model = (() => {
171-
switch (params.provider) {
172-
case "fal-ai":
173-
return FAL_AI_SUPPORTED_MODEL_IDS[task]?.[params.model];
174-
case "replicate":
175-
return REPLICATE_SUPPORTED_MODEL_IDS[task]?.[params.model];
176-
case "sambanova":
177-
return SAMBANOVA_SUPPORTED_MODEL_IDS[task]?.[params.model];
178-
case "together":
179-
return TOGETHER_SUPPORTED_MODEL_IDS[task]?.[params.model];
180-
}
181-
})();
182-
183-
if (!model) {
184-
throw new Error(`Model ${params.model} is not supported for task ${task} and provider ${params.provider}`);
185-
}
186-
return model;
187-
}
188-
189154
function makeUrl(params: {
190155
authMethod: "none" | "hf-token" | "credentials-include" | "provider-key";
191156
chatCompletion: boolean;
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import type { ModelId } from "../types";
2+
3+
type ProviderId = string;
4+
5+
/**
6+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co
7+
* for a given Inference Provider,
8+
* you can add it to the following dictionary, for dev purposes.
9+
*/
10+
export const HARDCODED_MODEL_ID_MAPPING: Record<ModelId, ProviderId> = {
11+
/**
12+
* "HF model ID" => "Model ID on Inference Provider's side"
13+
*/
14+
// "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
15+
};
Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,18 @@
1-
import type { ProviderMapping } from "./types";
2-
31
export const FAL_AI_API_BASE_URL = "https://fal.run";
42

5-
type FalAiId = string;
6-
7-
export const FAL_AI_SUPPORTED_MODEL_IDS: ProviderMapping<FalAiId> = {
8-
"text-to-image": {
9-
"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
10-
"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
11-
"playgroundai/playground-v2.5-1024px-aesthetic": "fal-ai/playground-v25",
12-
"ByteDance/SDXL-Lightning": "fal-ai/lightning-models",
13-
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "fal-ai/pixart-sigma",
14-
"stabilityai/stable-diffusion-3-medium": "fal-ai/stable-diffusion-v3-medium",
15-
"Warlord-K/Sana-1024": "fal-ai/sana",
16-
"fal/AuraFlow-v0.2": "fal-ai/aura-flow",
17-
"stabilityai/stable-diffusion-xl-base-1.0": "fal-ai/fast-sdxl",
18-
"stabilityai/stable-diffusion-3.5-large": "fal-ai/stable-diffusion-v35-large",
19-
"stabilityai/stable-diffusion-3.5-large-turbo": "fal-ai/stable-diffusion-v35-large/turbo",
20-
"stabilityai/stable-diffusion-3.5-medium": "fal-ai/stable-diffusion-v35-medium",
21-
"Kwai-Kolors/Kolors": "fal-ai/kolors",
22-
},
23-
"automatic-speech-recognition": {
24-
"openai/whisper-large-v3": "fal-ai/whisper",
25-
},
26-
"text-to-video": {
27-
"genmo/mochi-1-preview": "fal-ai/mochi-v1",
28-
"tencent/HunyuanVideo": "fal-ai/hunyuan-video",
29-
"THUDM/CogVideoX-5b": "fal-ai/cogvideox-5b",
30-
"Lightricks/LTX-Video": "fal-ai/ltx-video",
31-
},
32-
};
3+
/**
4+
* See the registered mapping of HF model ID => Fal model ID here:
5+
*
6+
* https://huggingface.co/api/partners/fal-ai/models
7+
*
8+
* This is a publicly available mapping.
9+
*
10+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
11+
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
12+
*
13+
* - If you work at Fal and want to update this mapping, please use the model mapping API we provide on huggingface.co
14+
* - If you're a community member and want to add a new supported HF model to Fal, please open an issue on the present repo
15+
* and we will tag Fal team members.
16+
*
17+
* Thanks!
18+
*/
Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,18 @@
1-
import type { ProviderMapping } from "./types";
2-
31
export const REPLICATE_API_BASE_URL = "https://api.replicate.com";
42

5-
type ReplicateId = string;
6-
7-
export const REPLICATE_SUPPORTED_MODEL_IDS: ProviderMapping<ReplicateId> = {
8-
"text-to-image": {
9-
"black-forest-labs/FLUX.1-dev": "black-forest-labs/flux-dev",
10-
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
11-
"ByteDance/Hyper-SD":
12-
"bytedance/hyper-flux-16step:382cf8959fb0f0d665b26e7e80b8d6dc3faaef1510f14ce017e8c732bb3d1eb7",
13-
"ByteDance/SDXL-Lightning":
14-
"bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637",
15-
"playgroundai/playground-v2.5-1024px-aesthetic":
16-
"playgroundai/playground-v2.5-1024px-aesthetic:a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24",
17-
"stabilityai/stable-diffusion-3.5-large-turbo": "stability-ai/stable-diffusion-3.5-large-turbo",
18-
"stabilityai/stable-diffusion-3.5-large": "stability-ai/stable-diffusion-3.5-large",
19-
"stabilityai/stable-diffusion-3.5-medium": "stability-ai/stable-diffusion-3.5-medium",
20-
"stabilityai/stable-diffusion-xl-base-1.0":
21-
"stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
22-
},
23-
"text-to-speech": {
24-
"OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:3c645149db020c85d080e2f8cfe482a0e68189a922cde964fa9e80fb179191f3",
25-
"hexgrad/Kokoro-82M": "jaaari/kokoro-82m:dfdf537ba482b029e0a761699e6f55e9162cfd159270bfe0e44857caa5f275a6",
26-
},
27-
"text-to-video": {
28-
"genmo/mochi-1-preview": "genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460",
29-
},
30-
};
3+
/**
4+
* See the registered mapping of HF model ID => Replicate model ID here:
5+
*
6+
* https://huggingface.co/api/partners/replicate/models
7+
*
8+
* This is a publicly available mapping.
9+
*
10+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
11+
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
12+
*
13+
* - If you work at Replicate and want to update this mapping, please use the model mapping API we provide on huggingface.co
14+
* - If you're a community member and want to add a new supported HF model to Replicate, please open an issue on the present repo
15+
* and we will tag Replicate team members.
16+
*
17+
* Thanks!
18+
*/
Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,18 @@
1-
import type { ProviderMapping } from "./types";
2-
31
export const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
42

5-
type SambanovaId = string;
6-
7-
export const SAMBANOVA_SUPPORTED_MODEL_IDS: ProviderMapping<SambanovaId> = {
8-
/** Chat completion / conversational */
9-
conversational: {
10-
"allenai/Llama-3.1-Tulu-3-405B":"Llama-3.1-Tulu-3-405B",
11-
"deepseek-ai/DeepSeek-R1-Distill-Llama-70B": "DeepSeek-R1-Distill-Llama-70B",
12-
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
13-
"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
14-
"Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
15-
"meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
16-
"meta-llama/Llama-3.2-1B-Instruct": "Meta-Llama-3.2-1B-Instruct",
17-
"meta-llama/Llama-3.2-3B-Instruct": "Meta-Llama-3.2-3B-Instruct",
18-
"meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
19-
"meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
20-
"meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
21-
"meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
22-
"meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
23-
"meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B",
24-
},
25-
};
3+
/**
4+
* See the registered mapping of HF model ID => Sambanova model ID here:
5+
*
6+
* https://huggingface.co/api/partners/sambanova/models
7+
*
8+
* This is a publicly available mapping.
9+
*
10+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
11+
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
12+
*
13+
* - If you work at Sambanova and want to update this mapping, please use the model mapping API we provide on huggingface.co
14+
* - If you're a community member and want to add a new supported HF model to Sambanova, please open an issue on the present repo
15+
* and we will tag Sambanova team members.
16+
*
17+
* Thanks!
18+
*/

0 commit comments

Comments
 (0)