diff --git a/packages/tasks/src/snippets/python.spec.ts b/packages/tasks/src/snippets/python.spec.ts index 1e502ea5f6..1bae17cb3d 100644 --- a/packages/tasks/src/snippets/python.spec.ts +++ b/packages/tasks/src/snippets/python.spec.ts @@ -104,4 +104,41 @@ stream = client.chat.completions.create( for chunk in stream: print(chunk.choices[0].delta.content, end="")`); }); + + it("text-to-image", async () => { + const model: ModelDataMinimal = { + id: "black-forest-labs/FLUX.1-schnell", + pipeline_tag: "text-to-image", + tags: [], + inference: "", + }; + const snippets = getPythonInferenceSnippet(model, "api_token") as InferenceSnippet[]; + + expect(snippets.length).toEqual(2); + + expect(snippets[0].client).toEqual("huggingface_hub"); + expect(snippets[0].content).toEqual(`from huggingface_hub import InferenceClient +client = InferenceClient("black-forest-labs/FLUX.1-schnell", token="api_token") + +# output is a PIL.Image object +image = client.text_to_image("Astronaut riding a horse")`); + + expect(snippets[1].client).toEqual("requests"); + expect(snippets[1].content).toEqual(`import requests + +API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell" +headers = {"Authorization": "Bearer api_token"} + +def query(payload): + response = requests.post(API_URL, headers=headers, json=payload) + return response.content +image_bytes = query({ + "inputs": "Astronaut riding a horse", +}) + +# You can access the image with PIL.Image for example +import io +from PIL import Image +image = Image.open(io.BytesIO(image_bytes))`); + }); }); diff --git a/packages/tasks/src/snippets/python.ts b/packages/tasks/src/snippets/python.ts index 31ce47a10b..bdb148e391 100644 --- a/packages/tasks/src/snippets/python.ts +++ b/packages/tasks/src/snippets/python.ts @@ -4,6 +4,11 @@ import { stringifyGenerationConfig, stringifyMessages } from "./common.js"; import { getModelInputSnippet } from "./inputs.js"; import type { InferenceSnippet, ModelDataMinimal } from "./types.js"; +const snippetImportInferenceClient = (model: ModelDataMinimal, accessToken: string): string => + `from huggingface_hub import InferenceClient +client = InferenceClient("${model.id}", token="${accessToken || "{API_TOKEN}"}") +`; + export const snippetConversational = ( model: ModelDataMinimal, accessToken: string, @@ -161,18 +166,28 @@ export const snippetFile = (model: ModelDataMinimal): InferenceSnippet => ({ output = query(${getModelInputSnippet(model)})`, }); -export const snippetTextToImage = (model: ModelDataMinimal): InferenceSnippet => ({ - content: `def query(payload): +export const snippetTextToImage = (model: ModelDataMinimal, accessToken: string): InferenceSnippet[] => [ + { + client: "huggingface_hub", + content: `${snippetImportInferenceClient(model, accessToken)} +# output is a PIL.Image object +image = client.text_to_image(${getModelInputSnippet(model)})`, + }, + { + client: "requests", + content: `def query(payload): response = requests.post(API_URL, headers=headers, json=payload) return response.content image_bytes = query({ "inputs": ${getModelInputSnippet(model)}, }) + # You can access the image with PIL.Image for example import io from PIL import Image image = Image.open(io.BytesIO(image_bytes))`, -}); + }, +]; export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet => ({ content: `def query(payload): @@ -288,12 +303,14 @@ export function getPythonInferenceSnippet( return snippets.map((snippet) => { return { ...snippet, - content: `import requests + content: snippet.content.includes("requests") + ? `import requests API_URL = "https://api-inference.huggingface.co/models/${model.id}" headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}} -${snippet.content}`, +${snippet.content}` + : snippet.content, }; }); }