packages/tasks-gen/package.json (2 additions & 1 deletion)
@@ -26,6 +26,7 @@
"type-fest": "^3.13.1"
},
"dependencies": {
"@huggingface/tasks": "workspace:^"
"@huggingface/tasks": "workspace:^",
"@huggingface/inference": "workspace:^"
}
}
packages/tasks-gen/pnpm-lock.yaml (3 additions & 0 deletions)

Some generated files are not rendered by default.

packages/tasks-gen/scripts/generate-snippets-fixtures.ts (36 additions & 16 deletions)
@@ -19,7 +19,7 @@ import { existsSync as pathExists } from "node:fs";
import * as fs from "node:fs/promises";
import * as path from "node:path/posix";

-import type { InferenceSnippet } from "@huggingface/tasks";
+import type { InferenceProvider, InferenceSnippet } from "@huggingface/tasks";
import { snippets } from "@huggingface/tasks";

type LANGUAGE = "sh" | "js" | "py";
@@ -28,6 +28,7 @@ const TEST_CASES: {
testName: string;
model: snippets.ModelDataMinimal;
languages: LANGUAGE[];
+	providers: InferenceProvider[];
opts?: Record<string, unknown>;
}[] = [
{
@@ -39,6 +40,7 @@
inference: "",
},
languages: ["sh", "js", "py"],
providers: ["hf-inference", "replicate"],
opts: { streaming: false },
},
{
@@ -50,6 +52,7 @@
inference: "",
},
languages: ["sh", "js", "py"],
providers: ["hf-inference"],
opts: { streaming: true },
},
{
@@ -61,6 +64,7 @@
inference: "",
},
languages: ["sh", "js", "py"],
providers: ["hf-inference"],
opts: { streaming: false },
},
{
@@ -72,6 +76,7 @@
inference: "",
},
languages: ["sh", "js", "py"],
providers: ["hf-inference"],
opts: { streaming: true },
},
{
@@ -82,6 +87,7 @@
tags: [],
inference: "",
},
providers: ["hf-inference"],
languages: ["sh", "js", "py"],
},
] as const;
@@ -113,31 +119,41 @@ function getFixtureFolder(testName: string): string {
function generateInferenceSnippet(
model: snippets.ModelDataMinimal,
language: LANGUAGE,
+	provider: InferenceProvider,
opts?: Record<string, unknown>
): InferenceSnippet[] {
-	const generatedSnippets = GET_SNIPPET_FN[language](model, "api_token", opts);
+	const generatedSnippets = GET_SNIPPET_FN[language](model, "api_token", provider, opts);
return Array.isArray(generatedSnippets) ? generatedSnippets : [generatedSnippets];
}
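
As an aside, the updated helper can be exercised directly for a single (model, language, provider) combination. A minimal sketch follows; the model entry is an assumption modeled on the fixtures further down, since the full TEST_CASES entries are collapsed in this diff:

// Hypothetical invocation of generateInferenceSnippet; the model entry is assumed, not copied from TEST_CASES.
const exampleModel: snippets.ModelDataMinimal = {
	id: "meta-llama/Llama-3.1-8B-Instruct",
	pipeline_tag: "text-generation",
	tags: ["conversational"],
	inference: "",
};
const jsSnippets = generateInferenceSnippet(exampleModel, "js", "replicate", { streaming: false });
// Returns one InferenceSnippet per client, each with an optional `client` name and a `content` string.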

-async function getExpectedInferenceSnippet(testName: string, language: LANGUAGE): Promise<InferenceSnippet[]> {
+async function getExpectedInferenceSnippet(
+	testName: string,
+	language: LANGUAGE,
+	provider: InferenceProvider
+): Promise<InferenceSnippet[]> {
const fixtureFolder = getFixtureFolder(testName);
const files = await fs.readdir(fixtureFolder);

const expectedSnippets: InferenceSnippet[] = [];
-	for (const file of files.filter((file) => file.endsWith("." + language)).sort()) {
-		const client = path.basename(file).split(".").slice(1, -1).join("."); // e.g. '0.huggingface.js.js' => "huggingface.js"
+	for (const file of files.filter((file) => file.endsWith("." + language) && file.includes(`.${provider}.`)).sort()) {
+		const client = path.basename(file).split(".").slice(1, -2).join("."); // e.g. '0.huggingface.js.replicate.js' => "huggingface.js"
const content = await fs.readFile(path.join(fixtureFolder, file), { encoding: "utf-8" });
expectedSnippets.push(client === "default" ? { content } : { client, content });
}
return expectedSnippets;
}
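
The client name is recovered purely from the fixture filename. A standalone sketch of that parsing, reusing the example from the comment above:

import * as path from "node:path/posix";

// Fixture names follow `${index}.${client}.${provider}.${language}` (see the save helper below).
const fixtureFile = "0.huggingface.js.replicate.js";
const clientName = path.basename(fixtureFile).split(".").slice(1, -2).join(".");
console.log(clientName); // "huggingface.js" (the index, provider, and extension segments are dropped)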

-async function saveExpectedInferenceSnippet(testName: string, language: LANGUAGE, snippets: InferenceSnippet[]) {
+async function saveExpectedInferenceSnippet(
+	testName: string,
+	language: LANGUAGE,
+	provider: InferenceProvider,
+	snippets: InferenceSnippet[]
+) {
const fixtureFolder = getFixtureFolder(testName);
await fs.mkdir(fixtureFolder, { recursive: true });

for (const [index, snippet] of snippets.entries()) {
-		const file = path.join(fixtureFolder, `${index}.${snippet.client ?? "default"}.${language}`);
+		const file = path.join(fixtureFolder, `${index}.${snippet.client ?? "default"}.${provider}.${language}`);
await fs.writeFile(file, snippet.content);
}
}
@@ -147,13 +163,15 @@ if (import.meta.vitest) {
const { describe, expect, it } = import.meta.vitest;

describe("inference API snippets", () => {
-		TEST_CASES.forEach(({ testName, model, languages, opts }) => {
+		TEST_CASES.forEach(({ testName, model, languages, providers, opts }) => {
describe(testName, () => {
languages.forEach((language) => {
-				it(language, async () => {
-					const generatedSnippets = generateInferenceSnippet(model, language, opts);
-					const expectedSnippets = await getExpectedInferenceSnippet(testName, language);
-					expect(generatedSnippets).toEqual(expectedSnippets);
+				providers.forEach((provider) => {
+					it(language, async () => {
+						const generatedSnippets = generateInferenceSnippet(model, language, provider, opts);
+						const expectedSnippets = await getExpectedInferenceSnippet(testName, language, provider);
+						expect(generatedSnippets).toEqual(expectedSnippets);
+					});
+				});
});
});
@@ -166,11 +184,13 @@
await fs.rm(path.join(rootDirFinder(), "snippets-fixtures"), { recursive: true, force: true });

console.debug(" 🏭 Generating new fixtures...");
-	TEST_CASES.forEach(({ testName, model, languages, opts }) => {
-		console.debug(`  ${testName} (${languages.join(", ")})`);
+	TEST_CASES.forEach(({ testName, model, languages, providers, opts }) => {
+		console.debug(`  ${testName} (${languages.join(", ")}) (${providers.join(", ")})`);
languages.forEach(async (language) => {
-			const generatedSnippets = generateInferenceSnippet(model, language, opts);
-			await saveExpectedInferenceSnippet(testName, language, generatedSnippets);
+			providers.forEach(async (provider) => {
+				const generatedSnippets = generateInferenceSnippet(model, language, provider, opts);
+				await saveExpectedInferenceSnippet(testName, language, provider, generatedSnippets);
+			});
});
});
console.log("✅ All done!");
@@ -0,0 +1,14 @@
curl 'https://huggingface.co/api/inference-proxy/replicate/v1/chat/completions' \
-H 'Authorization: Bearer api_token' \
-H 'Content-Type: application/json' \
--data '{
"model": "meta-llama/Llama-3.1-8B-Instruct",
"messages": [
{
"role": "user",
"content": "What is the capital of France?"
}
],
"max_tokens": 500,
"stream": false
}'
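
For comparison, the same proxied Replicate call can be issued from TypeScript with plain fetch. This is a sketch mirroring the curl fixture above; the response shape matches the chat-completion fixtures elsewhere in this PR:

// fetch equivalent of the curl fixture above (Node 18+ or a browser).
const response = await fetch("https://huggingface.co/api/inference-proxy/replicate/v1/chat/completions", {
	method: "POST",
	headers: {
		Authorization: "Bearer api_token",
		"Content-Type": "application/json",
	},
	body: JSON.stringify({
		model: "meta-llama/Llama-3.1-8B-Instruct",
		messages: [{ role: "user", content: "What is the capital of France?" }],
		max_tokens: 500,
		stream: false,
	}),
});
const completion = await response.json();
console.log(completion.choices[0].message);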
@@ -0,0 +1,17 @@
import { HfInference } from "@huggingface/inference";

const client = new HfInference("api_token");

const chatCompletion = await client.chatCompletion({
model: "meta-llama/Llama-3.1-8B-Instruct",
messages: [
{
role: "user",
content: "What is the capital of France?"
}
],
provider: "hf-inference",
max_tokens: 500
});

console.log(chatCompletion.choices[0].message);
@@ -10,6 +10,7 @@ const chatCompletion = await client.chatCompletion({
content: "What is the capital of France?"
}
],
provider: "replicate",
max_tokens: 500
});

@@ -0,0 +1,21 @@
from huggingface_hub import InferenceClient

client = InferenceClient(
provider="hf-inference",
api_key="api_token"
)

messages = [
{
"role": "user",
"content": "What is the capital of France?"
}
]

completion = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
messages=messages,
max_tokens=500
)

print(completion.choices[0].message)
@@ -1,6 +1,9 @@
from huggingface_hub import InferenceClient

-client = InferenceClient(api_key="api_token")
+client = InferenceClient(
+    provider="replicate",
+    api_key="api_token"
+)

messages = [
{
@@ -1,8 +1,8 @@
import { OpenAI } from "openai";

const client = new OpenAI({
baseURL: "https://api-inference.huggingface.co/v1/",
apiKey: "api_token"
baseURL: "https://api-inference.huggingface.co/v1/",
apiKey: "api_token"
});

const chatCompletion = await client.chat.completions.create({
@@ -0,0 +1,19 @@
import { OpenAI } from "openai";

const client = new OpenAI({
baseURL: "https://huggingface.co/api/inference-proxy/replicate",
apiKey: "api_token"
});

const chatCompletion = await client.chat.completions.create({
model: "meta-llama/Llama-3.1-8B-Instruct",
messages: [
{
role: "user",
content: "What is the capital of France?"
}
],
max_tokens: 500
});

console.log(chatCompletion.choices[0].message);
@@ -0,0 +1,21 @@
from openai import OpenAI

client = OpenAI(
base_url="https://huggingface.co/api/inference-proxy/replicate",
api_key="api_token"
)

messages = [
{
"role": "user",
"content": "What is the capital of France?"
}
]

completion = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
messages=messages,
max_tokens=500
)

print(completion.choices[0].message)
@@ -12,6 +12,7 @@ const stream = client.chatCompletionStream({
content: "What is the capital of France?"
}
],
provider: "hf-inference",
max_tokens: 500
});

@@ -1,6 +1,9 @@
from huggingface_hub import InferenceClient

-client = InferenceClient(api_key="api_token")
+client = InferenceClient(
+    provider="hf-inference",
+    api_key="api_token"
+)

messages = [
{
@@ -2,7 +2,7 @@ import { OpenAI } from "openai";

const client = new OpenAI({
baseURL: "https://api-inference.huggingface.co/v1/",
apiKey: "api_token"
apiKey: "api_token"
});

let out = "";
@@ -21,6 +21,7 @@ const chatCompletion = await client.chatCompletion({
]
}
],
provider: "hf-inference",
max_tokens: 500
});

@@ -1,6 +1,9 @@
from huggingface_hub import InferenceClient

-client = InferenceClient(api_key="api_token")
+client = InferenceClient(
+    provider="hf-inference",
+    api_key="api_token"
+)

messages = [
{
@@ -1,8 +1,8 @@
import { OpenAI } from "openai";

const client = new OpenAI({
baseURL: "https://api-inference.huggingface.co/v1/",
apiKey: "api_token"
baseURL: "https://api-inference.huggingface.co/v1/",
apiKey: "api_token"
});

const chatCompletion = await client.chat.completions.create({
@@ -23,6 +23,7 @@ const stream = client.chatCompletionStream({
]
}
],
provider: "hf-inference",
max_tokens: 500
});

@@ -1,6 +1,9 @@
from huggingface_hub import InferenceClient

-client = InferenceClient(api_key="api_token")
+client = InferenceClient(
+    provider="hf-inference",
+    api_key="api_token"
+)

messages = [
{
@@ -2,7 +2,7 @@ import { OpenAI } from "openai";

const client = new OpenAI({
baseURL: "https://api-inference.huggingface.co/v1/",
apiKey: "api_token"
apiKey: "api_token"
});

let out = "";
@@ -0,0 +1,11 @@
import { HfInference } from "@huggingface/inference";

const client = new HfInference("api_token");

const image = await client.textToImage({
model: "black-forest-labs/FLUX.1-schnell",
inputs: "Astronaut riding a horse",
parameters: { num_inference_steps: 5 },
provider: "hf-inference",
});
/// Use the generated image (it's a Blob)
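
A possible follow-up for that Blob in a Node.js context (a sketch; the output filename is arbitrary):

import { writeFile } from "node:fs/promises";

// Blob#arrayBuffer() is standard; convert it to a Buffer and persist the image.
const buffer = Buffer.from(await image.arrayBuffer());
await writeFile("astronaut.png", buffer);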