Commit 4ea6d32

Merge remote-tracking branch 'origin/main' into extension-fix

2 parents: 31c8f4f + a626ee5

File tree

7 files changed: +197 −21 lines

packages/tasks/package.json

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 {
 	"name": "@huggingface/tasks",
 	"packageManager": "[email protected]",
-	"version": "0.13.0",
+	"version": "0.13.1",
 	"description": "List of ML tasks for huggingface.co/tasks",
 	"repository": "https://github.com/huggingface/huggingface.js.git",
 	"publishConfig": {

packages/tasks/src/snippets/curl.spec.ts

Lines changed: 29 additions & 3 deletions
@@ -1,6 +1,6 @@
 import type { ModelDataMinimal } from "./types.js";
 import { describe, expect, it } from "vitest";
-import { snippetTextGeneration } from "./curl.js";
+import { getCurlInferenceSnippet } from "./curl.js";
 
 describe("inference API snippets", () => {
 	it("conversational llm", async () => {
@@ -10,7 +10,7 @@ describe("inference API snippets", () => {
 			tags: ["conversational"],
 			inference: "",
 		};
-		const snippet = snippetTextGeneration(model, "api_token");
+		const snippet = getCurlInferenceSnippet(model, "api_token");
 
 		expect(snippet.content)
 			.toEqual(`curl 'https://api-inference.huggingface.co/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions' \\
@@ -29,14 +29,40 @@ describe("inference API snippets", () => {
 }'`);
 	});
 
+	it("conversational llm non-streaming", async () => {
+		const model: ModelDataMinimal = {
+			id: "meta-llama/Llama-3.1-8B-Instruct",
+			pipeline_tag: "text-generation",
+			tags: ["conversational"],
+			inference: "",
+		};
+		const snippet = getCurlInferenceSnippet(model, "api_token", { streaming: false });
+
+		expect(snippet.content)
+			.toEqual(`curl 'https://api-inference.huggingface.co/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions' \\
+-H "Authorization: Bearer api_token" \\
+-H 'Content-Type: application/json' \\
+--data '{
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
+    "messages": [
+        {
+            "role": "user",
+            "content": "What is the capital of France?"
+        }
+    ],
+    "max_tokens": 500,
+    "stream": false
+}'`);
+	});
+
 	it("conversational vlm", async () => {
 		const model: ModelDataMinimal = {
 			id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
 			pipeline_tag: "image-text-to-text",
 			tags: ["conversational"],
 			inference: "",
 		};
-		const snippet = snippetTextGeneration(model, "api_token");
+		const snippet = getCurlInferenceSnippet(model, "api_token");
 
 		expect(snippet.content)
 			.toEqual(`curl 'https://api-inference.huggingface.co/models/meta-llama/Llama-3.2-11B-Vision-Instruct/v1/chat/completions' \\
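
The spec now exercises the public dispatcher instead of the per-task helper it replaced. As a rough usage sketch, mirroring the test setup above (the relative import paths assume in-package code):

import type { ModelDataMinimal } from "./types.js";
import { getCurlInferenceSnippet } from "./curl.js";

// Any conversational text-generation model descriptor works here.
const model: ModelDataMinimal = {
	id: "meta-llama/Llama-3.1-8B-Instruct",
	pipeline_tag: "text-generation",
	tags: ["conversational"],
	inference: "",
};

// The default output is the streaming chat-completion request;
// passing { streaming: false } produces the one-shot variant instead.
const streaming = getCurlInferenceSnippet(model, "api_token");
const oneShot = getCurlInferenceSnippet(model, "api_token", { streaming: false });
console.log(oneShot.content);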

packages/tasks/src/snippets/curl.ts

Lines changed: 6 additions & 2 deletions
@@ -105,9 +105,13 @@ export const curlSnippets: Partial<
 	"image-segmentation": snippetFile,
 };
 
-export function getCurlInferenceSnippet(model: ModelDataMinimal, accessToken: string): InferenceSnippet {
+export function getCurlInferenceSnippet(
+	model: ModelDataMinimal,
+	accessToken: string,
+	opts?: Record<string, unknown>
+): InferenceSnippet {
 	return model.pipeline_tag && model.pipeline_tag in curlSnippets
-		? curlSnippets[model.pipeline_tag]?.(model, accessToken) ?? { content: "" }
+		? curlSnippets[model.pipeline_tag]?.(model, accessToken, opts) ?? { content: "" }
 		: { content: "" };
 }
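
The new opts parameter is an untyped bag that the dispatcher forwards verbatim to whichever per-task generator matches model.pipeline_tag. A minimal sketch of the callee side, with a hypothetical generator name and placeholder content (the real generators in this file take the same (model, accessToken, opts) shape):

// Hypothetical generator: reads one option out of the untyped bag;
// leaving the key unset keeps the streaming default.
const snippetChatExample = (
	model: ModelDataMinimal,
	accessToken: string,
	opts?: Record<string, unknown>
): InferenceSnippet => {
	const streaming = opts?.streaming !== false;
	return { content: `curl request for ${model.id} (stream: ${streaming})` };
};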

packages/tasks/src/snippets/js.spec.ts

Lines changed: 65 additions & 3 deletions
@@ -1,6 +1,6 @@
 import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
 import { describe, expect, it } from "vitest";
-import { snippetTextGeneration } from "./js.js";
+import { getJsInferenceSnippet } from "./js.js";
 
 describe("inference API snippets", () => {
 	it("conversational llm", async () => {
@@ -10,7 +10,7 @@ describe("inference API snippets", () => {
 			tags: ["conversational"],
 			inference: "",
 		};
-		const snippet = snippetTextGeneration(model, "api_token") as InferenceSnippet[];
+		const snippet = getJsInferenceSnippet(model, "api_token") as InferenceSnippet[];
 
 		expect(snippet[0].content).toEqual(`import { HfInference } from "@huggingface/inference"
 
@@ -38,14 +38,41 @@ for await (const chunk of stream) {
 }`);
 	});
 
+	it("conversational llm non-streaming", async () => {
+		const model: ModelDataMinimal = {
+			id: "meta-llama/Llama-3.1-8B-Instruct",
+			pipeline_tag: "text-generation",
+			tags: ["conversational"],
+			inference: "",
+		};
+		const snippet = getJsInferenceSnippet(model, "api_token", { streaming: false }) as InferenceSnippet[];
+
+		expect(snippet[0].content).toEqual(`import { HfInference } from "@huggingface/inference"
+
+const client = new HfInference("api_token")
+
+const chatCompletion = await client.chatCompletion({
+	model: "meta-llama/Llama-3.1-8B-Instruct",
+	messages: [
+		{
+			role: "user",
+			content: "What is the capital of France?"
+		}
+	],
+	max_tokens: 500
+});
+
+console.log(chatCompletion.choices[0].message);`);
+	});
+
 	it("conversational vlm", async () => {
 		const model: ModelDataMinimal = {
 			id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
 			pipeline_tag: "image-text-to-text",
 			tags: ["conversational"],
 			inference: "",
 		};
-		const snippet = snippetTextGeneration(model, "api_token") as InferenceSnippet[];
+		const snippet = getJsInferenceSnippet(model, "api_token") as InferenceSnippet[];
 
 		expect(snippet[0].content).toEqual(`import { HfInference } from "@huggingface/inference"
 
@@ -75,6 +102,41 @@ const stream = client.chatCompletionStream({
 	max_tokens: 500
 });
 
+for await (const chunk of stream) {
+	if (chunk.choices && chunk.choices.length > 0) {
+		const newContent = chunk.choices[0].delta.content;
+		out += newContent;
+		console.log(newContent);
+	}
+}`);
+	});
+
+	it("conversational llm", async () => {
+		const model: ModelDataMinimal = {
+			id: "meta-llama/Llama-3.1-8B-Instruct",
+			pipeline_tag: "text-generation",
+			tags: ["conversational"],
+			inference: "",
+		};
+		const snippet = getJsInferenceSnippet(model, "api_token") as InferenceSnippet[];
+
+		expect(snippet[0].content).toEqual(`import { HfInference } from "@huggingface/inference"
+
+const client = new HfInference("api_token")
+
+let out = "";
+
+const stream = client.chatCompletionStream({
+	model: "meta-llama/Llama-3.1-8B-Instruct",
+	messages: [
+		{
+			role: "user",
+			content: "What is the capital of France?"
+		}
+	],
+	max_tokens: 500
+});
+
 for await (const chunk of stream) {
 	if (chunk.choices && chunk.choices.length > 0) {
 		const newContent = chunk.choices[0].delta.content;

packages/tasks/src/snippets/js.ts

Lines changed: 4 additions & 3 deletions
@@ -109,7 +109,7 @@ for await (const chunk of stream) {
 	return [
 		{
 			client: "huggingface.js",
-			content: `import { HfInference } from '@huggingface/inference'
+			content: `import { HfInference } from "@huggingface/inference"
 
 const client = new HfInference("${accessToken || `{API_TOKEN}`}")
 
@@ -292,10 +292,11 @@ export const jsSnippets: Partial<
 
 export function getJsInferenceSnippet(
 	model: ModelDataMinimal,
-	accessToken: string
+	accessToken: string,
+	opts?: Record<string, unknown>
 ): InferenceSnippet | InferenceSnippet[] {
 	return model.pipeline_tag && model.pipeline_tag in jsSnippets
-		? jsSnippets[model.pipeline_tag]?.(model, accessToken) ?? { content: "" }
+		? jsSnippets[model.pipeline_tag]?.(model, accessToken, opts) ?? { content: "" }
 		: { content: "" };
 }
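
Because getJsInferenceSnippet is typed to return InferenceSnippet | InferenceSnippet[], callers such as the specs above cast the result. One way a consumer might normalize the union — an illustrative helper, not part of the package:

import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
import { getJsInferenceSnippet } from "./js.js";

// Normalize the union: always return an array of snippets,
// whichever shape the task-specific generator produced.
function getJsSnippetList(
	model: ModelDataMinimal,
	accessToken: string,
	opts?: Record<string, unknown>
): InferenceSnippet[] {
	const result = getJsInferenceSnippet(model, accessToken, opts);
	return Array.isArray(result) ? result : [result];
}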

packages/tasks/src/snippets/python.spec.ts

Lines changed: 70 additions & 4 deletions
@@ -1,6 +1,6 @@
-import type { ModelDataMinimal } from "./types.js";
+import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
 import { describe, expect, it } from "vitest";
-import { snippetConversational } from "./python.js";
+import { getPythonInferenceSnippet } from "./python.js";
 
 describe("inference API snippets", () => {
 	it("conversational llm", async () => {
@@ -10,7 +10,7 @@ describe("inference API snippets", () => {
 			tags: ["conversational"],
 			inference: "",
 		};
-		const snippet = snippetConversational(model, "api_token");
+		const snippet = getPythonInferenceSnippet(model, "api_token") as InferenceSnippet[];
 
 		expect(snippet[0].content).toEqual(`from huggingface_hub import InferenceClient
 
@@ -34,14 +34,43 @@ for chunk in stream:
 	print(chunk.choices[0].delta.content, end="")`);
 	});
 
+	it("conversational llm non-streaming", async () => {
+		const model: ModelDataMinimal = {
+			id: "meta-llama/Llama-3.1-8B-Instruct",
+			pipeline_tag: "text-generation",
+			tags: ["conversational"],
+			inference: "",
+		};
+		const snippet = getPythonInferenceSnippet(model, "api_token", { streaming: false }) as InferenceSnippet[];
+
+		expect(snippet[0].content).toEqual(`from huggingface_hub import InferenceClient
+
+client = InferenceClient(api_key="api_token")
+
+messages = [
+	{
+		"role": "user",
+		"content": "What is the capital of France?"
+	}
+]
+
+completion = client.chat.completions.create(
+	model="meta-llama/Llama-3.1-8B-Instruct",
+	messages=messages,
+	max_tokens=500
+)
+
+print(completion.choices[0].message)`);
+	});
+
 	it("conversational vlm", async () => {
 		const model: ModelDataMinimal = {
 			id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
 			pipeline_tag: "image-text-to-text",
 			tags: ["conversational"],
 			inference: "",
 		};
-		const snippet = snippetConversational(model, "api_token");
+		const snippet = getPythonInferenceSnippet(model, "api_token") as InferenceSnippet[];
 
 		expect(snippet[0].content).toEqual(`from huggingface_hub import InferenceClient
 
@@ -75,4 +104,41 @@ stream = client.chat.completions.create(
 for chunk in stream:
 	print(chunk.choices[0].delta.content, end="")`);
 	});
+
+	it("text-to-image", async () => {
+		const model: ModelDataMinimal = {
+			id: "black-forest-labs/FLUX.1-schnell",
+			pipeline_tag: "text-to-image",
+			tags: [],
+			inference: "",
+		};
+		const snippets = getPythonInferenceSnippet(model, "api_token") as InferenceSnippet[];
+
+		expect(snippets.length).toEqual(2);
+
+		expect(snippets[0].client).toEqual("huggingface_hub");
+		expect(snippets[0].content).toEqual(`from huggingface_hub import InferenceClient
+client = InferenceClient("black-forest-labs/FLUX.1-schnell", token="api_token")
+
+# output is a PIL.Image object
+image = client.text_to_image("Astronaut riding a horse")`);
+
+		expect(snippets[1].client).toEqual("requests");
+		expect(snippets[1].content).toEqual(`import requests
+
+API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
+headers = {"Authorization": "Bearer api_token"}
+
+def query(payload):
+	response = requests.post(API_URL, headers=headers, json=payload)
+	return response.content
+image_bytes = query({
+	"inputs": "Astronaut riding a horse",
+})
+
+# You can access the image with PIL.Image for example
+import io
+from PIL import Image
+image = Image.open(io.BytesIO(image_bytes))`);
+	});
 });
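
The new text-to-image test pins down the multi-client contract: one call yields two snippets distinguished by their client field ("huggingface_hub" and "requests"). A caller rendering one tab per client might do roughly the following (the helper is hypothetical, and client is assumed optional on InferenceSnippet):

import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
import { getPythonInferenceSnippet } from "./python.js";

// Print every generated Python snippet under its client label.
function printPythonSnippets(model: ModelDataMinimal, accessToken: string): void {
	const snippets = getPythonInferenceSnippet(model, accessToken) as InferenceSnippet[];
	for (const snippet of snippets) {
		console.log(`--- ${snippet.client ?? "default"} ---`);
		console.log(snippet.content);
	}
}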

packages/tasks/src/snippets/python.ts

Lines changed: 22 additions & 5 deletions
@@ -4,6 +4,11 @@ import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
 import { getModelInputSnippet } from "./inputs.js";
 import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
 
+const snippetImportInferenceClient = (model: ModelDataMinimal, accessToken: string): string =>
+	`from huggingface_hub import InferenceClient
+client = InferenceClient("${model.id}", token="${accessToken || "{API_TOKEN}"}")
+`;
+
 export const snippetConversational = (
 	model: ModelDataMinimal,
 	accessToken: string,
@@ -161,18 +166,28 @@ export const snippetFile = (model: ModelDataMinimal): InferenceSnippet => ({
 output = query(${getModelInputSnippet(model)})`,
 });
 
-export const snippetTextToImage = (model: ModelDataMinimal): InferenceSnippet => ({
-	content: `def query(payload):
+export const snippetTextToImage = (model: ModelDataMinimal, accessToken: string): InferenceSnippet[] => [
+	{
+		client: "huggingface_hub",
+		content: `${snippetImportInferenceClient(model, accessToken)}
+# output is a PIL.Image object
+image = client.text_to_image(${getModelInputSnippet(model)})`,
+	},
+	{
+		client: "requests",
+		content: `def query(payload):
 	response = requests.post(API_URL, headers=headers, json=payload)
 	return response.content
 image_bytes = query({
 	"inputs": ${getModelInputSnippet(model)},
 })
+
 # You can access the image with PIL.Image for example
 import io
 from PIL import Image
 image = Image.open(io.BytesIO(image_bytes))`,
-});
+	},
+];
 
 export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet => ({
 	content: `def query(payload):
@@ -288,12 +303,14 @@ export function getPythonInferenceSnippet(
 	return snippets.map((snippet) => {
 		return {
 			...snippet,
-			content: `import requests
+			content: snippet.content.includes("requests")
+				? `import requests
 
 API_URL = "https://api-inference.huggingface.co/models/${model.id}"
 headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
 
-${snippet.content}`,
+${snippet.content}`
+				: snippet.content,
 		};
 	});
 }
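
The mapping step at the bottom now guards on snippet.content.includes("requests"), so the shared API_URL/headers preamble is prepended only to requests-based snippets, while the huggingface_hub variant ships self-contained (it already carries its own import via snippetImportInferenceClient). Reduced to its core, the transformation is roughly this sketch, not the exact code:

// Prepend the shared requests preamble only when the snippet body
// actually mentions the requests library; leave other clients untouched.
const addRequestsPreamble = (snippet: InferenceSnippet, preamble: string): InferenceSnippet => ({
	...snippet,
	content: snippet.content.includes("requests") ? `${preamble}\n\n${snippet.content}` : snippet.content,
});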
