Skip to content

Commit f416a61

Browse files
committed
Document python text to image snippets
1 parent bc7381c commit f416a61

File tree

2 files changed

+57
-7
lines changed

2 files changed

+57
-7
lines changed

packages/tasks/src/snippets/python.spec.ts

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import type { ModelDataMinimal } from "./types";
1+
import type { InferenceSnippet, ModelDataMinimal } from "./types";
22
import { describe, expect, it } from "vitest";
3-
import { snippetConversational } from "./python";
3+
import { snippetConversational, getPythonInferenceSnippet } from "./python";
44

55
describe("inference API snippets", () => {
66
it("conversational llm", async () => {
@@ -75,4 +75,39 @@ stream = client.chat.completions.create(
7575
for chunk in stream:
7676
print(chunk.choices[0].delta.content, end="")`);
7777
});
78+
79+
it("text-to-image", async () => {
80+
const model: ModelDataMinimal = {
81+
id: "black-forest-labs/FLUX.1-schnell",
82+
pipeline_tag: "text-to-image",
83+
tags: [],
84+
inference: "",
85+
};
86+
const snippets = getPythonInferenceSnippet(model, "api_token") as InferenceSnippet[];
87+
88+
expect(snippets.length).toEqual(2);
89+
90+
expect(snippets[0].client).toEqual("huggingface_hub");
91+
expect(snippets[0].content).toEqual(`from huggingface_hub import InferenceClient
92+
client = InferenceClient("black-forest-labs/FLUX.1-schnell", token="api_token")
93+
# output is a PIL.Image object
94+
image = client.text_to_image("Astronaut riding a horse")`);
95+
96+
expect(snippets[1].client).toEqual("requests");
97+
expect(snippets[1].content).toEqual(`import requests
98+
99+
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
100+
headers = {"Authorization": "Bearer api_token"}
101+
102+
def query(payload):
103+
response = requests.post(API_URL, headers=headers, json=payload)
104+
return response.content
105+
image_bytes = query({
106+
"inputs": "Astronaut riding a horse",
107+
})
108+
# You can access the image with PIL.Image for example
109+
import io
110+
from PIL import Image
111+
image = Image.open(io.BytesIO(image_bytes))`);
112+
});
78113
});

packages/tasks/src/snippets/python.ts

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
44
import { getModelInputSnippet } from "./inputs.js";
55
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
66

7+
const snippetImportInferenceClient = (model: ModelDataMinimal, accessToken: string): string =>
8+
`from huggingface_hub import InferenceClient
9+
client = InferenceClient("${model.id}", token="${accessToken || "{API_TOKEN}"}")`;
10+
711
export const snippetConversational = (
812
model: ModelDataMinimal,
913
accessToken: string,
@@ -161,8 +165,16 @@ export const snippetFile = (model: ModelDataMinimal): InferenceSnippet => ({
161165
output = query(${getModelInputSnippet(model)})`,
162166
});
163167

164-
export const snippetTextToImage = (model: ModelDataMinimal): InferenceSnippet => ({
165-
content: `def query(payload):
168+
export const snippetTextToImage = (model: ModelDataMinimal, accessToken: string): InferenceSnippet[] => [
169+
{
170+
client: "huggingface_hub",
171+
content: `${snippetImportInferenceClient(model, accessToken)}
172+
# output is a PIL.Image object
173+
image = client.text_to_image(${getModelInputSnippet(model)})`,
174+
},
175+
{
176+
client: "requests",
177+
content: `def query(payload):
166178
response = requests.post(API_URL, headers=headers, json=payload)
167179
return response.content
168180
image_bytes = query({
@@ -172,7 +184,8 @@ image_bytes = query({
172184
import io
173185
from PIL import Image
174186
image = Image.open(io.BytesIO(image_bytes))`,
175-
});
187+
},
188+
];
176189

177190
export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet => ({
178191
content: `def query(payload):
@@ -288,12 +301,14 @@ export function getPythonInferenceSnippet(
288301
return snippets.map((snippet) => {
289302
return {
290303
...snippet,
291-
content: `import requests
304+
content: snippet.content.includes("requests")
305+
? `import requests
292306
293307
API_URL = "https://api-inference.huggingface.co/models/${model.id}"
294308
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
295309
296-
${snippet.content}`,
310+
${snippet.content}`
311+
: snippet.content,
297312
};
298313
});
299314
}

0 commit comments

Comments
 (0)