
Commit d084ae2

SBrandeis authored and aykutkardas committed
[Inference] [Providers] Enforce task in mapping + expose them (huggingface#1109)
- Add task metadata to the HF id -> Provider id mappings, to forbid, for example, using a chat model with the `textToImage` inference function
- Expose the supported models mappings in `index.ts`
1 parent fbd01bc commit d084ae2
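
To summarize the shape change the commit message describes, here is an illustrative sketch (simplified types, not code taken from the diff); the real mappings and diffs follow below.

```ts
// Illustrative only: simplified before/after shape of a provider mapping entry.
// Before: a flat HF model id -> provider id record, with no notion of task.
const before: Record<string, string> = {
	"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
};

// After: keyed by task first, so a model registered only under "conversational"
// can never be resolved for a "text-to-image" request.
const after: Partial<Record<string, Record<string, string>>> = {
	"text-to-image": {
		"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
	},
};

console.log(before["black-forest-labs/FLUX.1-schnell"]); // "fal-ai/flux/schnell"
console.log(after["text-to-image"]?.["black-forest-labs/FLUX.1-schnell"]); // "fal-ai/flux/schnell"
```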

File tree

8 files changed, +200 -173 lines changed


packages/inference/src/index.ts

Lines changed: 4 additions & 0 deletions
@@ -1,4 +1,8 @@
 export { HfInference, HfInferenceEndpoint } from "./HfInference";
 export { InferenceOutputError } from "./lib/InferenceOutputError";
+export { FAL_AI_SUPPORTED_MODEL_IDS } from "./providers/fal-ai";
+export { REPLICATE_SUPPORTED_MODEL_IDS } from "./providers/replicate";
+export { SAMBANOVA_SUPPORTED_MODEL_IDS } from "./providers/sambanova";
+export { TOGETHER_SUPPORTED_MODEL_IDS } from "./providers/together";
 export * from "./types";
 export * from "./tasks";
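
Because the mappings are now re-exported from the package entry point, consumers can inspect provider coverage directly. A minimal sketch, assuming the package is consumed under its npm name `@huggingface/inference` and using the task-first structure shown in the provider diffs below:

```ts
import { FAL_AI_SUPPORTED_MODEL_IDS } from "@huggingface/inference"; // assumed package name

// Look up the fal.ai app id for an HF model, scoped to a task.
const falId = FAL_AI_SUPPORTED_MODEL_IDS["text-to-image"]?.["black-forest-labs/FLUX.1-schnell"];
console.log(falId); // "fal-ai/flux/schnell" according to the fal-ai mapping in this commit
```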

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 24 additions & 14 deletions
@@ -1,8 +1,8 @@
 import { HF_HUB_URL, HF_INFERENCE_API_URL } from "../config";
-import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai";
-import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate";
-import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
-import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
+import { FAL_AI_API_BASE_URL, FAL_AI_SUPPORTED_MODEL_IDS } from "../providers/fal-ai";
+import { REPLICATE_API_BASE_URL, REPLICATE_SUPPORTED_MODEL_IDS } from "../providers/replicate";
+import { SAMBANOVA_API_BASE_URL, SAMBANOVA_SUPPORTED_MODEL_IDS } from "../providers/sambanova";
+import { TOGETHER_API_BASE_URL, TOGETHER_SUPPORTED_MODEL_IDS } from "../providers/together";
 import type { InferenceProvider } from "../types";
 import type { InferenceTask, Options, RequestArgs } from "../types";
 import { isUrl } from "./isUrl";

@@ -50,13 +50,13 @@ export async function makeRequestOptions(
 	let model: string;
 	if (!maybeModel) {
 		if (taskHint) {
-			model = mapModel({ model: await loadDefaultModel(taskHint), provider });
+			model = mapModel({ model: await loadDefaultModel(taskHint), provider, taskHint, chatCompletion });
 		} else {
 			throw new Error("No model provided, and no default model found for this task");
 			/// TODO : change error message ^
 		}
 	} else {
-		model = mapModel({ model: maybeModel, provider });
+		model = mapModel({ model: maybeModel, provider, taskHint, chatCompletion });
 	}

 	/// If accessToken is passed, it should take precedence over includeCredentials

@@ -143,24 +143,34 @@
 	return { url, info };
 }

-function mapModel(params: { model: string; provider: InferenceProvider }): string {
+function mapModel(params: {
+	model: string;
+	provider: InferenceProvider;
+	taskHint: InferenceTask | undefined;
+	chatCompletion: boolean | undefined;
+}): string {
+	if (params.provider === "hf-inference") {
+		return params.model;
+	}
+	if (!params.taskHint) {
+		throw new Error("taskHint must be specified when using a third-party provider");
+	}
+	const task = params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
 	const model = (() => {
 		switch (params.provider) {
 			case "fal-ai":
-				return FAL_AI_MODEL_IDS[params.model];
+				return FAL_AI_SUPPORTED_MODEL_IDS[task]?.[params.model];
 			case "replicate":
-				return REPLICATE_MODEL_IDS[params.model];
+				return REPLICATE_SUPPORTED_MODEL_IDS[task]?.[params.model];
 			case "sambanova":
-				return SAMBANOVA_MODEL_IDS[params.model];
+				return SAMBANOVA_SUPPORTED_MODEL_IDS[task]?.[params.model];
 			case "together":
-				return TOGETHER_MODEL_IDS[params.model]?.id;
-			case "hf-inference":
-				return params.model;
+				return TOGETHER_SUPPORTED_MODEL_IDS[task]?.[params.model];
 		}
 	})();

 	if (!model) {
-		throw new Error(`Model ${params.model} is not supported for provider ${params.provider}`);
+		throw new Error(`Model ${params.model} is not supported for task ${task} and provider ${params.provider}`);
 	}
 	return model;
 }
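
The task-resolution rule that `mapModel` now applies can be shown in isolation. This is a standalone sketch of the logic above, not an exported helper of the library:

```ts
// Standalone sketch of the rule added to mapModel (simplified types).
type Task = string;

function resolveTask(taskHint: Task | undefined, chatCompletion: boolean | undefined): Task {
	if (!taskHint) {
		// Third-party providers need a task to select the right sub-mapping.
		throw new Error("taskHint must be specified when using a third-party provider");
	}
	// A text-generation request issued through the chat-completion API is looked up
	// under the "conversational" key of the provider mapping.
	return taskHint === "text-generation" && chatCompletion ? "conversational" : taskHint;
}

console.log(resolveTask("text-generation", true)); // "conversational"
console.log(resolveTask("text-to-image", undefined)); // "text-to-image"
```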

packages/inference/src/providers/fal-ai.ts

Lines changed: 70 additions & 70 deletions
@@ -1,76 +1,76 @@
-import type { ModelId } from "../types";
+import { ModelId } from "../types";
+import type { ProviderMapping } from "./types";

 export const FAL_AI_API_BASE_URL = "https://fal.run";

 type FalAiId = string;

-/**
- * Mapping from HF model ID -> fal.ai app id
- */
-export const FAL_AI_MODEL_IDS: Record<ModelId, FalAiId> = {
-	/** text-to-image */
-	"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
-	"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
-	"black-forest-labs/FLUX.1-Depth-dev": "fal-ai/flux-lora-depth",
-	"black-forest-labs/FLUX.1-Canny-dev": "fal-ai/flux-lora-canny",
-	"black-forest-labs/FLUX.1-Redux-dev": "fal-ai/flux/dev/redux",
-	"playgroundai/playground-v2.5-1024px-aesthetic": "fal-ai/playground-v25",
-	"ByteDance/SDXL-Lightning": "fal-ai/lightning-models",
-	"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "fal-ai/pixart-sigma",
-	"stabilityai/stable-diffusion-3-medium": "fal-ai/stable-diffusion-v3-medium",
-	"Warlord-K/Sana-1024": "fal-ai/sana",
-	"fal/AuraFlow-v0.2": "fal-ai/aura-flow",
-	"stabilityai/stable-diffusion-3.5-large": "fal-ai/stable-diffusion-v35-large",
-	"yresearch/Switti": "fal-ai/switti",
-	"guozinan/PuLID": "fal-ai/flux-pulid",
-	"lllyasviel/ic-light": "fal-ai/iclight-v2",
-	"stabilityai/stable-diffusion-xl-base-1.0": "fal-ai/lora",
-	"Kwai-Kolors/Kolors": "fal-ai/kolors",
-
-	/** image-to-image */
-	"Yuanshi/OminiControl": "fal-ai/flux-subject",
-	"fal/AuraSR-v2": "fal-ai/aura-sr",
-	"franciszzj/Leffa": "fal-ai/leffa",
-	"ai-forever/Real-ESRGAN": "fal-ai/esrgan",
-
-	/** image-segmentation */
-	"briaai/RMBG-2.0": "fal-ai/bria/background/remove",
-	"ZhengPeng7/BiRefNet": "fal-ai/birefnet/v2",
-
-	/** text-to-video */
-	"genmo/mochi-1-preview": "fal-ai/mochi-v1",
-	"THUDM/CogVideoX-5b": "fal-ai/cogvideox-5b",
-	"Lightricks/LTX-Video": "fal-ai/ltx-video",
-	"tencent/HunyuanVideo": "fal-ai/hunyuan-video",
-	"wileewang/TransPixar": "fal-ai/transpixar",
-
-	/** image-to-video */
-	"stabilityai/stable-video-1.0": "fal-ai/stable-video",
-	"KwaiVGI/LivePortrait": "fal-ai/live-portrait",
-
-	/** text-to-audio */
-	"hkchengrex/MMAudio": "fal-ai/mmaudio-v2",
-	"stabilityai/stable-audio-open-1.0": "fal-ai/stable-audio",
-
-	/** text-to-speech */
-	"SWivid/F5-TTS": "fal-ai/f5-tts",
-
-	/** image-text-to-text */
-	"vikhyatk/moondream-next": "fal-ai/moondream-next",
-	"microsoft/Florence-2-large": "fal-ai/florence-2-large/caption",
-	"ByteDance/Sa2VA-8B": "fal-ai/sa2va/8b/image/playground",
-
-	/** mask-generation */
-	"facebook/sam2-hiera-large": "fal-ai/sam2",
-
-	/** image-to-3d */
-	"JeffreyXiang/TRELLIS-image-large": "fal-ai/trellis",
-
-	/** depth-estimation */
-	"Intel/dpt-hybrid-midas": "fal-ai/imageutils/depth",
-	"prs-eth/marigold-depth-v1-0": "fal-ai/imageutils/marigold-depth",
-	"depth-anything/Depth-Anything-V2-Large": "fal-ai/image-preprocessors/depth-anything/v2",
-
-	/** automatic-speech-recognition */
-	"openai/whisper-large-v3": "fal-ai/whisper",
+export const FAL_AI_SUPPORTED_MODEL_IDS: ProviderMapping<FalAiId> = {
+	"text-to-image": {
+		"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
+		"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
+		"black-forest-labs/FLUX.1-Depth-dev": "fal-ai/flux-lora-depth",
+		"black-forest-labs/FLUX.1-Canny-dev": "fal-ai/flux-lora-canny",
+		"black-forest-labs/FLUX.1-Fill-dev": "fal-ai/flux-lora-fill",
+		"black-forest-labs/FLUX.1-Redux-dev": "fal-ai/flux/dev/redux",
+		"playgroundai/playground-v2.5-1024px-aesthetic": "fal-ai/playground-v25",
+		"ByteDance/SDXL-Lightning": "fal-ai/lightning-models",
+		"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "fal-ai/pixart-sigma",
+		"stabilityai/stable-diffusion-3-medium": "fal-ai/stable-diffusion-v3-medium",
+		"Warlord-K/Sana-1024": "fal-ai/sana",
+		"fal/AuraFlow-v0.2": "fal-ai/aura-flow",
+		"stabilityai/stable-diffusion-3.5-large": "fal-ai/stable-diffusion-v35-large",
+		"yresearch/Switti": "fal-ai/switti",
+		"guozinan/PuLID": "fal-ai/flux-pulid",
+		"lllyasviel/ic-light": "fal-ai/iclight-v2",
+		"stabilityai/stable-diffusion-xl-base-1.0": "fal-ai/lora",
+		"Kwai-Kolors/Kolors": "fal-ai/kolors"
+	},
+	"image-to-image": {
+		"Yuanshi/OminiControl": "fal-ai/flux-subject",
+		"fal/AuraSR-v2": "fal-ai/aura-sr",
+		"franciszzj/Leffa": "fal-ai/leffa",
+		"ai-forever/Real-ESRGAN": "fal-ai/esrgan"
+	},
+	"image-segmentation": {
+		"briaai/RMBG-2.0": "fal-ai/bria/background/remove",
+		"ZhengPeng7/BiRefNet": "fal-ai/birefnet/v2"
+	},
+	"text-to-video": {
+		"genmo/mochi-1-preview": "fal-ai/mochi-v1",
+		"THUDM/CogVideoX-5b": "fal-ai/cogvideox-5b",
+		"Lightricks/LTX-Video": "fal-ai/ltx-video",
+		"tencent/HunyuanVideo": "fal-ai/hunyuan-video",
+		"wileewang/TransPixar": "fal-ai/transpixar"
+	},
+	"image-to-video": {
+		"stabilityai/stable-video-1.0": "fal-ai/stable-video",
+		"KwaiVGI/LivePortrait": "fal-ai/live-portrait"
+	},
+	"text-to-audio": {
+		"hkchengrex/MMAudio": "fal-ai/mmaudio-v2",
+		"stabilityai/stable-audio-open-1.0": "fal-ai/stable-audio"
+	},
+	"text-to-speech": {
+		"SWivid/F5-TTS": "fal-ai/f5-tts"
+	},
+	"image-text-to-text": {
+		"vikhyatk/moondream-next": "fal-ai/moondream-next",
+		"microsoft/Florence-2-large": "fal-ai/florence-2-large/caption",
+		"ByteDance/Sa2VA-8B": "fal-ai/sa2va/8b/image/playground"
+	},
+	"mask-generation": {
+		"facebook/sam2-hiera-large": "fal-ai/sam2"
+	},
+	"image-to-3d": {
+		"JeffreyXiang/TRELLIS-image-large": "fal-ai/trellis"
+	},
+	"depth-estimation": {
+		"Intel/dpt-hybrid-midas": "fal-ai/imageutils/depth",
+		"prs-eth/marigold-depth-v1-0": "fal-ai/imageutils/marigold-depth",
+		"depth-anything/Depth-Anything-V2-Large": "fal-ai/image-preprocessors/depth-anything/v2"
+	},
+	"automatic-speech-recognition": {
+		"openai/whisper-large-v3": "fal-ai/whisper"
+	}
 };
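
The `ProviderMapping` type imported from `./types` is not part of this excerpt. Based on how it is used here, it is presumably close to the following sketch (the task key type and the contents of providers/types.ts are assumptions):

```ts
// Assumed shape of ProviderMapping; packages/inference/src/providers/types.ts is not shown in this excerpt.
import type { ModelId } from "../types";

// Task names such as "text-to-image" or "conversational", as used as keys above.
type TaskName = string;

export type ProviderMapping<ProviderId extends string> = Partial<
	Record<TaskName, Partial<Record<ModelId, ProviderId>>>
>;
```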

packages/inference/src/providers/replicate.ts

Lines changed: 10 additions & 16 deletions
@@ -1,22 +1,16 @@
-import type { ModelId } from "../types";
+import type { ProviderMapping } from "./types";

 export const REPLICATE_API_BASE_URL = "https://api.replicate.com";

 type ReplicateId = string;

-/**
- * Mapping from HF model ID -> Replicate model ID
- *
- * Available models can be fetched with:
- * ```
- * curl -s \
- *   -H "Authorization: Bearer $REPLICATE_API_TOKEN" \
- *   'https://api.replicate.com/v1/models'
- * ```
- */
-export const REPLICATE_MODEL_IDS: Partial<Record<ModelId, ReplicateId>> = {
-	/** text-to-image */
-	"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
-	"ByteDance/SDXL-Lightning":
-		"bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637",
+export const REPLICATE_SUPPORTED_MODEL_IDS: ProviderMapping<ReplicateId> = {
+	"text-to-image": {
+		"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
+		"ByteDance/SDXL-Lightning":
+			"bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637",
+	},
+	// "text-to-speech": {
+	// 	"SWivid/F5-TTS": "x-lance/f5-tts:87faf6dd7a692dd82043f662e76369cab126a2cf1937e25a9d41e0b834fd230e"
+	// },
 };

packages/inference/src/providers/sambanova.ts

Lines changed: 16 additions & 25 deletions
@@ -1,32 +1,23 @@
-import type { ModelId } from "../types";
+import type { ProviderMapping } from "./types";

 export const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";

-/**
- * Note for reviewers: our goal would be to ask Sambanova to support
- * our model ids too, so we don't have to define a mapping
- * or keep it up-to-date.
- *
- * As a fallback, if the above is not possible, ask Sambanova to
- * provide the mapping as an fetchable API.
- */
 type SambanovaId = string;

-/**
- * https://community.sambanova.ai/t/supported-models/193
- */
-export const SAMBANOVA_MODEL_IDS: Partial<Record<ModelId, SambanovaId>> = {
+export const SAMBANOVA_SUPPORTED_MODEL_IDS: ProviderMapping<SambanovaId> = {
 	/** Chat completion / conversational */
-	"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
-	"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
-	"Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
-	"meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
-	"meta-llama/Llama-3.2-1B": "Meta-Llama-3.2-1B-Instruct",
-	"meta-llama/Llama-3.2-3B": "Meta-Llama-3.2-3B-Instruct",
-	"meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
-	"meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
-	"meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
-	"meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
-	"meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
-	"meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B",
+	conversational: {
+		"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
+		"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
+		"Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
+		"meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
+		"meta-llama/Llama-3.2-1B-Instruct": "Meta-Llama-3.2-1B-Instruct",
+		"meta-llama/Llama-3.2-3B-Instruct": "Meta-Llama-3.2-3B-Instruct",
+		"meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
+		"meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
+		"meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
+		"meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
+		"meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
+		"meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B",
+	},
 };
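
To illustrate the enforcement described in the commit message: because lookups are scoped by task first, a chat-only model cannot be resolved for a text-to-image request. A sketch, assuming the package name `@huggingface/inference`:

```ts
import { SAMBANOVA_SUPPORTED_MODEL_IDS } from "@huggingface/inference"; // assumed package name

const model = "meta-llama/Llama-3.3-70B-Instruct";

// Listed under "conversational", so chat completion resolves normally.
console.log(SAMBANOVA_SUPPORTED_MODEL_IDS["conversational"]?.[model]); // "Meta-Llama-3.3-70B-Instruct"

// Not listed under "text-to-image", so mapModel would throw:
// `Model meta-llama/Llama-3.3-70B-Instruct is not supported for task text-to-image and provider sambanova`
console.log(SAMBANOVA_SUPPORTED_MODEL_IDS["text-to-image"]?.[model]); // undefined
```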
