
Commit 827093f

Merge remote-tracking branch 'upstream/main'

2 parents: 0e5a6c8 + 930e30b

92 files changed: +8787 / −7825 lines

README.md

Lines changed: 13 additions & 13 deletions
@@ -27,7 +27,7 @@ await uploadFile({
   }
 });

-// Use HF Inference API, or external Inference Providers!
+// Use all supported Inference Providers!

 await inference.chatCompletion({
   model: "meta-llama/Llama-3.1-8B-Instruct",
@@ -55,7 +55,7 @@ await inference.textToImage({

 This is a collection of JS libraries to interact with the Hugging Face API, with TS types included.

-- [@huggingface/inference](packages/inference/README.md): Use HF Inference API (serverless), Inference Endpoints (dedicated) and all supported Inference Providers to make calls to 100,000+ Machine Learning models
+- [@huggingface/inference](packages/inference/README.md): Use all supported (serverless) Inference Providers or switch to Inference Endpoints (dedicated) to make calls to 100,000+ Machine Learning models
 - [@huggingface/hub](packages/hub/README.md): Interact with huggingface.co to create or delete repos and commit / download files
 - [@huggingface/agents](packages/agents/README.md): Interact with HF models through a natural language interface
 - [@huggingface/gguf](packages/gguf/README.md): A GGUF parser that works on remotely hosted files.
@@ -97,7 +97,7 @@ You can run our packages with vanilla JS, without any bundler, by using a CDN or

 ```html
 <script type="module">
-  import { InferenceClient } from 'https://cdn.jsdelivr.net/npm/@huggingface/inference@3.7.1/+esm';
+  import { InferenceClient } from 'https://cdn.jsdelivr.net/npm/@huggingface/inference@3.8.1/+esm';
   import { createRepo, commit, deleteRepo, listFiles } from "https://cdn.jsdelivr.net/npm/@huggingface/[email protected]/+esm";
 </script>
 ```
@@ -128,18 +128,18 @@ import { InferenceClient } from "@huggingface/inference";

 const HF_TOKEN = "hf_...";

-const inference = new InferenceClient(HF_TOKEN);
+const client = new InferenceClient(HF_TOKEN);

 // Chat completion API
-const out = await inference.chatCompletion({
+const out = await client.chatCompletion({
   model: "meta-llama/Llama-3.1-8B-Instruct",
   messages: [{ role: "user", content: "Hello, nice to meet you!" }],
   max_tokens: 512
 });
 console.log(out.choices[0].message);

 // Streaming chat completion API
-for await (const chunk of inference.chatCompletionStream({
+for await (const chunk of client.chatCompletionStream({
   model: "meta-llama/Llama-3.1-8B-Instruct",
   messages: [{ role: "user", content: "Hello, nice to meet you!" }],
   max_tokens: 512
@@ -148,14 +148,14 @@ for await (const chunk of inference.chatCompletionStream({
 }

 /// Using a third-party provider:
-await inference.chatCompletion({
+await client.chatCompletion({
   model: "meta-llama/Llama-3.1-8B-Instruct",
   messages: [{ role: "user", content: "Hello, nice to meet you!" }],
   max_tokens: 512,
   provider: "sambanova", // or together, fal-ai, replicate, cohere …
 })

-await inference.textToImage({
+await client.textToImage({
   model: "black-forest-labs/FLUX.1-dev",
   inputs: "a picture of a green bird",
   provider: "fal-ai",
@@ -164,7 +164,7 @@ await inference.textToImage({


 // You can also omit "model" to use the recommended model for the task
-await inference.translation({
+await client.translation({
   inputs: "My name is Wolfgang and I live in Amsterdam",
   parameters: {
     src_lang: "en",
@@ -173,17 +173,17 @@ await inference.translation({
 });

 // pass multimodal files or URLs as inputs
-await inference.imageToText({
+await client.imageToText({
   model: 'nlpconnect/vit-gpt2-image-captioning',
   data: await (await fetch('https://picsum.photos/300/300')).blob(),
 })

 // Using your own dedicated inference endpoint: https://hf.co/docs/inference-endpoints/
-const gpt2 = inference.endpoint('https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2');
-const { generated_text } = await gpt2.textGeneration({ inputs: 'The answer to the universe is' });
+const gpt2Client = client.endpoint('https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2');
+const { generated_text } = await gpt2Client.textGeneration({ inputs: 'The answer to the universe is' });

 // Chat Completion
-const llamaEndpoint = inference.endpoint(
+const llamaEndpoint = client.endpoint(
   "https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct"
 );
 const out = await llamaEndpoint.chatCompletion({

packages/inference/package.json

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 {
   "name": "@huggingface/inference",
-  "version": "3.7.1",
+  "version": "3.8.1",
   "packageManager": "[email protected]",
   "license": "MIT",
   "author": "Hugging Face and Tim Mikeladze <[email protected]>",

packages/inference/src/lib/getInferenceProviderMapping.ts

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import type { WidgetType } from "@huggingface/tasks";
2+
import type { InferenceProvider, ModelId } from "../types";
3+
import { HF_HUB_URL } from "../config";
4+
import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts";
5+
import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference";
6+
import { typedInclude } from "../utils/typedInclude";
7+
8+
export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
9+
10+
export type InferenceProviderMapping = Partial<
11+
Record<InferenceProvider, Omit<InferenceProviderModelMapping, "hfModelId" | "adapterWeightsPath">>
12+
>;
13+
14+
export interface InferenceProviderModelMapping {
15+
adapter?: string;
16+
adapterWeightsPath?: string;
17+
hfModelId: ModelId;
18+
providerId: string;
19+
status: "live" | "staging";
20+
task: WidgetType;
21+
}
22+
23+
export async function getInferenceProviderMapping(
24+
params: {
25+
accessToken?: string;
26+
modelId: ModelId;
27+
provider: InferenceProvider;
28+
task: WidgetType;
29+
},
30+
options: {
31+
fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
32+
}
33+
): Promise<InferenceProviderModelMapping | null> {
34+
if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
35+
return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
36+
}
37+
let inferenceProviderMapping: InferenceProviderMapping | null;
38+
if (inferenceProviderMappingCache.has(params.modelId)) {
39+
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
40+
inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId)!;
41+
} else {
42+
const resp = await (options?.fetch ?? fetch)(
43+
`${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
44+
{
45+
headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {},
46+
}
47+
);
48+
if (resp.status === 404) {
49+
throw new Error(`Model ${params.modelId} does not exist`);
50+
}
51+
inferenceProviderMapping = await resp
52+
.json()
53+
.then((json) => json.inferenceProviderMapping)
54+
.catch(() => null);
55+
}
56+
57+
if (!inferenceProviderMapping) {
58+
throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
59+
}
60+
61+
const providerMapping = inferenceProviderMapping[params.provider];
62+
if (providerMapping) {
63+
const equivalentTasks =
64+
params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task)
65+
? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS
66+
: [params.task];
67+
if (!typedInclude(equivalentTasks, providerMapping.task)) {
68+
throw new Error(
69+
`Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
70+
);
71+
}
72+
if (providerMapping.status === "staging") {
73+
console.warn(
74+
`Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
75+
);
76+
}
77+
if (providerMapping.adapter === "lora") {
78+
const treeResp = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${params.modelId}/tree/main`);
79+
if (!treeResp.ok) {
80+
throw new Error(`Unable to fetch the model tree for ${params.modelId}.`);
81+
}
82+
const tree: Array<{ type: "file" | "directory"; path: string }> = await treeResp.json();
83+
const adapterWeightsPath = tree.find(({ type, path }) => type === "file" && path.endsWith(".safetensors"))?.path;
84+
if (!adapterWeightsPath) {
85+
throw new Error(`No .safetensors file found in the model tree for ${params.modelId}.`);
86+
}
87+
return {
88+
...providerMapping,
89+
hfModelId: params.modelId,
90+
adapterWeightsPath,
91+
};
92+
}
93+
return { ...providerMapping, hfModelId: params.modelId };
94+
}
95+
return null;
96+
}
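
For orientation, here is a minimal sketch of how the new helper could be called directly; the relative import path, model ID, provider, and token are illustrative assumptions, not part of this commit:

```ts
import { getInferenceProviderMapping } from "./lib/getInferenceProviderMapping";

// Illustrative call: resolve the provider-side model ID for a Hub model (values are made up).
const mapping = await getInferenceProviderMapping(
  {
    modelId: "meta-llama/Llama-3.1-8B-Instruct",
    provider: "sambanova",
    task: "conversational",
    accessToken: "hf_...",
  },
  { fetch } // a custom fetch can be injected here, e.g. in tests
);

if (mapping) {
  // providerId is the model name on the provider's side; status is "live" or "staging".
  console.log(mapping.providerId, mapping.status, mapping.task);
}
```

Note that the module-level `inferenceProviderMappingCache` map is consulted first, so pre-populated entries skip the Hub request entirely.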

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 1 addition & 0 deletions
@@ -124,6 +124,7 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
   },
   sambanova: {
     conversational: new Sambanova.SambanovaConversationalTask(),
+    "feature-extraction": new Sambanova.SambanovaFeatureExtractionTask(),
   },
   together: {
     "text-to-image": new Together.TogetherTextToImageTask(),

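A hedged sketch of what this one-line registration enables from the public client API; the embedding model below is an illustrative assumption, not taken from this commit:

```ts
import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient("hf_...");

// Feature extraction (embeddings) routed through the Sambanova provider.
// The model name is illustrative; any model mapped to Sambanova for this task would work.
const embeddings = await client.featureExtraction({
  model: "intfloat/e5-mistral-7b-instruct",
  inputs: "Today is a sunny day and I will get some ice cream.",
  provider: "sambanova",
});

console.log(embeddings);
```
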
packages/inference/src/lib/getProviderModelId.ts

Lines changed: 0 additions & 74 deletions
This file was deleted.

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 41 additions & 10 deletions
@@ -1,8 +1,9 @@
 import { name as packageName, version as packageVersion } from "../../package.json";
 import { HF_HEADER_X_BILL_TO, HF_HUB_URL } from "../config";
 import type { InferenceTask, Options, RequestArgs } from "../types";
+import type { InferenceProviderModelMapping } from "./getInferenceProviderMapping";
+import { getInferenceProviderMapping } from "./getInferenceProviderMapping";
 import type { getProviderHelper } from "./getProviderHelper";
-import { getProviderModelId } from "./getProviderModelId";
 import { isUrl } from "./isUrl";

 /**
@@ -40,7 +41,13 @@ export async function makeRequestOptions(

   if (args.endpointUrl) {
     // No need to have maybeModel, or to load default model for a task
-    return makeRequestOptionsFromResolvedModel(maybeModel ?? args.endpointUrl, providerHelper, args, options);
+    return makeRequestOptionsFromResolvedModel(
+      maybeModel ?? args.endpointUrl,
+      providerHelper,
+      args,
+      undefined,
+      options
+    );
   }

   if (!maybeModel && !task) {
@@ -54,16 +61,38 @@ export async function makeRequestOptions(
     throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
   }

-  const resolvedModel = providerHelper.clientSideRoutingOnly
-    ? // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
-      removeProviderPrefix(maybeModel!, provider)
-    : await getProviderModelId({ model: hfModel, provider }, args, {
-        task,
-        fetch: options?.fetch,
-      });
+  const inferenceProviderMapping = providerHelper.clientSideRoutingOnly
+    ? ({
+        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+        providerId: removeProviderPrefix(maybeModel!, provider),
+        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+        hfModelId: maybeModel!,
+        status: "live",
+        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+        task: task!,
+      } satisfies InferenceProviderModelMapping)
+    : await getInferenceProviderMapping(
+        {
+          modelId: hfModel,
+          // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+          task: task!,
+          provider,
+          accessToken: args.accessToken,
+        },
+        { fetch: options?.fetch }
+      );
+  if (!inferenceProviderMapping) {
+    throw new Error(`We have not been able to find inference provider information for model ${hfModel}.`);
+  }

   // Use the sync version with the resolved model
-  return makeRequestOptionsFromResolvedModel(resolvedModel, providerHelper, args, options);
+  return makeRequestOptionsFromResolvedModel(
+    inferenceProviderMapping.providerId,
+    providerHelper,
+    args,
+    inferenceProviderMapping,
+    options
+  );
 }

 /**
@@ -77,6 +106,7 @@ export function makeRequestOptionsFromResolvedModel(
     data?: Blob | ArrayBuffer;
     stream?: boolean;
   },
+  mapping: InferenceProviderModelMapping | undefined,
   options?: Options & {
     task?: InferenceTask;
   }
@@ -138,6 +168,7 @@ export function makeRequestOptionsFromResolvedModel(
     args: remainingArgs as Record<string, unknown>,
     model: resolvedModel,
     task,
+    mapping,
   });
   /**
    * For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
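
To make the new `mapping` parameter concrete, this is roughly the object that `makeRequestOptions` now threads into `makeRequestOptionsFromResolvedModel` and on to the provider helper (field values are illustrative):

```ts
import type { InferenceProviderModelMapping } from "./getInferenceProviderMapping";

// Example shape only; in practice this comes from getInferenceProviderMapping() or the
// hardcoded / client-side-routing fallbacks shown in the diff above.
const mapping: InferenceProviderModelMapping = {
  hfModelId: "meta-llama/Llama-3.1-8B-Instruct", // model ID on the Hugging Face Hub
  providerId: "Meta-Llama-3.1-8B-Instruct", // model ID on the provider's side (illustrative)
  status: "live",
  task: "conversational",
};
```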

packages/inference/src/providers/consts.ts

Lines changed: 5 additions & 2 deletions
@@ -1,15 +1,18 @@
+import type { InferenceProviderModelMapping } from "../lib/getInferenceProviderMapping";
 import type { InferenceProvider } from "../types";
 import { type ModelId } from "../types";

-type ProviderId = string;
 /**
  * If you want to try to run inference for a new model locally before it's registered on huggingface.co
  * for a given Inference Provider,
  * you can add it to the following dictionary, for dev purposes.
  *
  * We also inject into this dictionary from tests.
  */
-export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelId, ProviderId>> = {
+export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
+  InferenceProvider,
+  Record<ModelId, InferenceProviderModelMapping>
+> = {
   /**
    * "HF model ID" => "Model ID on Inference Provider's side"
    *
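
Since the dictionary's values are now full `InferenceProviderModelMapping` objects rather than bare provider model IDs, a dev-time entry would look roughly like the sketch below; the import path, provider, model IDs, and task are illustrative assumptions:

```ts
import { HARDCODED_MODEL_INFERENCE_MAPPING } from "@huggingface/inference/src/providers/consts";

// Register a model that is not yet mapped on huggingface.co, for local testing only.
HARDCODED_MODEL_INFERENCE_MAPPING["fal-ai"]["my-org/my-flux-finetune"] = {
  hfModelId: "my-org/my-flux-finetune",
  providerId: "fal-ai/flux-lora", // provider-side ID, made up for illustration
  status: "staging",
  task: "text-to-image",
};
```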
