42 changes: 42 additions & 0 deletions .github/workflows/inference-check-snippets.yml
@@ -0,0 +1,42 @@
name: Inference check snippets
on:
  pull_request:
    paths:
      - "packages/tasks/src/snippets/**"
      - ".github/workflows/inference-check-snippets.yml"

jobs:
  check-snippets:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v3

      - run: corepack enable

      - uses: actions/setup-node@v3
        with:
          node-version: "20"
          cache: "pnpm"
          cache-dependency-path: "**/pnpm-lock.yaml"
      - run: |
          cd packages/tasks
          pnpm install

      # TODO: Find a way to run on all pipeline tags (see the sketch after this file)
      # TODO: print snippet only if it has changed since the last commit on main (?)
      # TODO: (even better: automated message on the PR with diff)
      - name: Print text-to-image snippets
        run: |
          cd packages/tasks
          pnpm run check-snippets --pipeline-tag="text-to-image"

      - name: Print simple text-generation snippets
        run: |
          cd packages/tasks
          pnpm run check-snippets --pipeline-tag="text-generation"

      - name: Print conversational text-generation snippets
        run: |
          cd packages/tasks
          pnpm run check-snippets --pipeline-tag="text-generation" --tags="conversational"
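The first TODO above could be addressed inside the script itself rather than with one workflow step per tag. A minimal sketch of such a loop, assuming PIPELINE_DATA is exported from src/pipelines as it is elsewhere in @huggingface/tasks (this loop is hypothetical, not part of the PR):

// Hypothetical loop over every pipeline tag for check-snippets.ts.
// Assumes PIPELINE_DATA is exported from src/pipelines; the helper it
// mentions is defined in check-snippets.ts below.
import { PIPELINE_DATA } from "../src/pipelines";
import type { PipelineType } from "../src/pipelines";

for (const pipelineTag of Object.keys(PIPELINE_DATA) as PipelineType[]) {
	console.log(`\n===== ${pipelineTag} =====`);
	// Call generateAndPrintSnippets() for each generator here, skipping
	// pipeline tags that have no snippet template.
}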
3 changes: 2 additions & 1 deletion packages/tasks/package.json
@@ -32,7 +32,8 @@
"check": "tsc",
"inference-codegen": "tsx scripts/inference-codegen.ts && prettier --write src/tasks/*/inference.ts",
"inference-tgi-import": "tsx scripts/inference-tgi-import.ts && prettier --write src/tasks/text-generation/spec/*.json && prettier --write src/tasks/chat-completion/spec/*.json",
"inference-tei-import": "tsx scripts/inference-tei-import.ts && prettier --write src/tasks/feature-extraction/spec/*.json"
"inference-tei-import": "tsx scripts/inference-tei-import.ts && prettier --write src/tasks/feature-extraction/spec/*.json",
"check-snippets": "tsx scripts/check-snippets.ts"
},
"type": "module",
"files": [
55 changes: 55 additions & 0 deletions packages/tasks/scripts/check-snippets.ts
@@ -0,0 +1,55 @@
mishig25 (Collaborator) commented on Oct 31, 2024:

I've implemented tests in https://github.com/huggingface/huggingface.js/pull/1003/files using the vitest convention of xyz.spec.ts files (which run on pnpm test).

I think we should just put more tests into xyz.spec.ts files rather than creating a custom mechanism of check-snippets.ts & inference-check-snippets.yml. Or am I missing some necessary details?

Wdyt?

Contributor Author replied:

much better, yes!

Contributor Author replied:

I went with this solution because I was frustrated by how difficult it was to debug things locally, but automated tests feel much more natural now that you mention it ^^

/*
 * Generates inference snippets as they would be shown on the Hub for Curl, JS and Python.
 * Snippets are only printed to the terminal, to make it easier to debug changes to the snippet templates.
 *
 * Usage:
 *   pnpm run check-snippets --pipeline-tag="text-generation" --tags="conversational"
 *   pnpm run check-snippets --pipeline-tag="image-text-to-text" --tags="conversational"
 *   pnpm run check-snippets --pipeline-tag="text-to-image"
 *
 * This script is meant for debugging purposes only.
 */
import { python, curl, js } from "../src/snippets/index";
import type { InferenceSnippet, ModelDataMinimal } from "../src/snippets/types";
import type { PipelineType } from "../src/pipelines";

// Parse command-line arguments
const args = process.argv.slice(2).reduce(
	(acc, arg) => {
		const [key, value] = arg.split("=");
		acc[key.replace("--", "")] = value;
		return acc;
	},
	{} as { [key: string]: string }
);

const accessToken = "hf_**********";
const pipelineTag = (args["pipeline-tag"] || "text-generation") as PipelineType;
const tags = (args["tags"] || "").split(",");

const modelMinimal: ModelDataMinimal = {
	id: "llama-6-1720B-Instruct",
	pipeline_tag: pipelineTag,
	tags: tags,
	inference: "****",
};

const printSnippets = (snippets: InferenceSnippet | InferenceSnippet[], language: string) => {
	const snippetArray = Array.isArray(snippets) ? snippets : [snippets];
	snippetArray.forEach((snippet) => {
		// Print the "<language> <client>" header in yellow (ANSI escape codes), then the snippet
		console.log(`\n\x1b[33m${language} ${snippet.client}\x1b[0m`);
		console.log(`\n\`\`\`${language === "JS" ? "js" : language.toLowerCase()}\n${snippet.content}\n\`\`\`\n`);
	});
};

const generateAndPrintSnippets = (
	generator: (model: ModelDataMinimal, token: string) => InferenceSnippet | InferenceSnippet[],
	language: string
) => {
	const snippets = generator(modelMinimal, accessToken);
	printSnippets(snippets, language);
};

generateAndPrintSnippets(curl.getCurlInferenceSnippet, "Curl");
generateAndPrintSnippets(python.getPythonInferenceSnippet, "Python");
generateAndPrintSnippets(js.getJsInferenceSnippet, "JS");
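Following mishig25's suggestion in the thread above, the same checks could live in vitest spec files that run on pnpm test. A rough sketch of what a python.spec.ts could look like (file name and assertions are illustrative, not taken from PR #1003):

// Hypothetical python.spec.ts following the xyz.spec.ts convention.
import { describe, expect, it } from "vitest";
import { getPythonInferenceSnippet } from "./python";
import type { InferenceSnippet, ModelDataMinimal } from "./types";

describe("python inference snippets", () => {
	it("text-to-image exposes both requests and huggingface_hub clients", () => {
		const model: ModelDataMinimal = {
			id: "llama-6-1720B-Instruct", // same dummy model id as check-snippets.ts
			pipeline_tag: "text-to-image",
			tags: [],
			inference: "",
		};
		const snippets = getPythonInferenceSnippet(model, "hf_**********") as InferenceSnippet[];
		expect(snippets.map((s) => s.client)).toEqual(["requests", "huggingface_hub"]);
	});
});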
27 changes: 24 additions & 3 deletions packages/tasks/src/snippets/python.ts
@@ -4,6 +4,11 @@ import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
import { getModelInputSnippet } from "./inputs.js";
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";

const snippetImportInferenceClient = (model: ModelDataMinimal, accessToken: string): string =>
	`from huggingface_hub import InferenceClient

client = InferenceClient("${model.id}", token="${accessToken || "{API_TOKEN}"}")`;

export const snippetConversational = (
	model: ModelDataMinimal,
	accessToken: string,
@@ -168,18 +173,31 @@ export const snippetFile = (model: ModelDataMinimal): InferenceSnippet => ({
output = query(${getModelInputSnippet(model)})`,
});

export const snippetTextToImage = (model: ModelDataMinimal): InferenceSnippet => ({
	content: `def query(payload):
export const snippetTextToImage = (model: ModelDataMinimal, accessToken: string): InferenceSnippet[] => {
	return [
		{
			client: "requests",
			content: `def query(payload):
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.content

image_bytes = query({
	"inputs": ${getModelInputSnippet(model)},
})
# You can access the image with PIL.Image for example
import io
from PIL import Image
image = Image.open(io.BytesIO(image_bytes))`,
});
		},
		{
			client: "huggingface_hub",
			content: `${snippetImportInferenceClient(model, accessToken)}

# output is a PIL.Image object
image = client.text_to_image(${getModelInputSnippet(model)})`,
		},
	];
};

mishig25 (Collaborator) commented on the huggingface_hub entry:

maybe let's agree on the convention of putting huggingface_hub as the first item in the list

export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet => ({
	content: `def query(payload):
@@ -284,6 +302,9 @@ export function getPythonInferenceSnippet(
	if (model.tags.includes("conversational")) {
		// Conversational model detected, so we display a code snippet that features the Messages API
		return snippetConversational(model, accessToken, opts);
	} else if (model.pipeline_tag === "text-to-image") {
		// TODO: factorize this logic (see the sketch below)
		return snippetTextToImage(model, accessToken);
	} else {
		let snippets =
			model.pipeline_tag && model.pipeline_tag in pythonSnippets
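The "factorize this logic" TODO could be handled with a lookup table for the generators that take an access token, keyed by pipeline tag. A hypothetical sketch, reusing names already in scope in python.ts:

// Hypothetical factorization (not part of this PR): snippet generators that
// need an access token and return several client variants.
const pythonSnippetsWithClients: Partial<
	Record<PipelineType, (model: ModelDataMinimal, accessToken: string) => InferenceSnippet[]>
> = {
	"text-to-image": snippetTextToImage,
};
// getPythonInferenceSnippet could then consult this table before falling back
// to the plain pythonSnippets map.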