diff --git a/packages/agents/pnpm-lock.yaml b/packages/agents/pnpm-lock.yaml index 060aacb353..455c7460ab 100644 --- a/packages/agents/pnpm-lock.yaml +++ b/packages/agents/pnpm-lock.yaml @@ -7,7 +7,7 @@ settings: dependencies: '@huggingface/inference': specifier: ^2.6.1 - version: link:../inference + version: 2.8.1 devDependencies: '@types/node': @@ -16,6 +16,17 @@ devDependencies: packages: + /@huggingface/inference@2.8.1: + resolution: {integrity: sha512-EfsNtY9OR6JCNaUa5bZu2mrs48iqeTz0Gutwf+fU0Kypx33xFQB4DKMhp8u4Ee6qVbLbNWvTHuWwlppLQl4p4Q==} + engines: {node: '>=18'} + dependencies: + '@huggingface/tasks': 0.12.30 + dev: false + + /@huggingface/tasks@0.12.30: + resolution: {integrity: sha512-A1ITdxbEzx9L8wKR8pF7swyrTLxWNDFIGDLUWInxvks2ruQ8PLRBZe8r0EcjC3CDdtlj9jV1V4cgV35K/iy3GQ==} + dev: false + /@types/node@18.13.0: resolution: {integrity: sha512-gC3TazRzGoOnoKAhUx+Q0t8S9Tzs74z7m0ipwGpSqQrleP14hKxP4/JUeEQcD3W1/aIpnWl8pHowI7WokuZpXg==} dev: true diff --git a/packages/inference/README.md b/packages/inference/README.md index 4cc2e68818..d054e1686a 100644 --- a/packages/inference/README.md +++ b/packages/inference/README.md @@ -1,4 +1,4 @@ -# 🤗 Hugging Face Inference Endpoints +# 🤗 Hugging Face Inference A Typescript powered wrapper for the Hugging Face Inference API (serverless), Inference Endpoints (dedicated), and third-party Inference Providers. It works with [Inference API (serverless)](https://huggingface.co/docs/api-inference/index) and [Inference Endpoints (dedicated)](https://huggingface.co/docs/inference-endpoints/index), and even with supported third-party Inference Providers. diff --git a/packages/tasks-gen/package.json b/packages/tasks-gen/package.json index 23ad43cfc8..81f836b6f0 100644 --- a/packages/tasks-gen/package.json +++ b/packages/tasks-gen/package.json @@ -26,6 +26,7 @@ "type-fest": "^3.13.1" }, "dependencies": { - "@huggingface/tasks": "workspace:^" + "@huggingface/tasks": "workspace:^", + "@huggingface/inference": "workspace:^" } } diff --git a/packages/tasks-gen/pnpm-lock.yaml b/packages/tasks-gen/pnpm-lock.yaml index f754b702eb..bedd30656e 100644 --- a/packages/tasks-gen/pnpm-lock.yaml +++ b/packages/tasks-gen/pnpm-lock.yaml @@ -5,6 +5,9 @@ settings: excludeLinksFromLockfile: false dependencies: + '@huggingface/inference': + specifier: workspace:^ + version: link:../inference '@huggingface/tasks': specifier: workspace:^ version: link:../tasks diff --git a/packages/tasks-gen/scripts/generate-snippets-fixtures.ts b/packages/tasks-gen/scripts/generate-snippets-fixtures.ts index 1107d78a50..d6a3a61ab7 100644 --- a/packages/tasks-gen/scripts/generate-snippets-fixtures.ts +++ b/packages/tasks-gen/scripts/generate-snippets-fixtures.ts @@ -19,7 +19,7 @@ import { existsSync as pathExists } from "node:fs"; import * as fs from "node:fs/promises"; import * as path from "node:path/posix"; -import type { InferenceSnippet } from "@huggingface/tasks"; +import type { InferenceProvider, InferenceSnippet } from "@huggingface/tasks"; import { snippets } from "@huggingface/tasks"; type LANGUAGE = "sh" | "js" | "py"; @@ -28,6 +28,7 @@ const TEST_CASES: { testName: string; model: snippets.ModelDataMinimal; languages: LANGUAGE[]; + providers: InferenceProvider[]; opts?: Record; }[] = [ { @@ -39,6 +40,7 @@ const TEST_CASES: { inference: "", }, languages: ["sh", "js", "py"], + providers: ["hf-inference", "together"], opts: { streaming: false }, }, { @@ -50,6 +52,7 @@ const TEST_CASES: { inference: "", }, languages: ["sh", "js", "py"], + providers: ["hf-inference"], opts: { streaming: true 
}, }, { @@ -61,6 +64,7 @@ const TEST_CASES: { inference: "", }, languages: ["sh", "js", "py"], + providers: ["hf-inference"], opts: { streaming: false }, }, { @@ -72,6 +76,7 @@ const TEST_CASES: { inference: "", }, languages: ["sh", "js", "py"], + providers: ["hf-inference"], opts: { streaming: true }, }, { @@ -82,6 +87,7 @@ const TEST_CASES: { tags: [], inference: "", }, + providers: ["hf-inference"], languages: ["sh", "js", "py"], }, ] as const; @@ -113,31 +119,41 @@ function getFixtureFolder(testName: string): string { function generateInferenceSnippet( model: snippets.ModelDataMinimal, language: LANGUAGE, + provider: InferenceProvider, opts?: Record ): InferenceSnippet[] { - const generatedSnippets = GET_SNIPPET_FN[language](model, "api_token", opts); + const generatedSnippets = GET_SNIPPET_FN[language](model, "api_token", provider, opts); return Array.isArray(generatedSnippets) ? generatedSnippets : [generatedSnippets]; } -async function getExpectedInferenceSnippet(testName: string, language: LANGUAGE): Promise { +async function getExpectedInferenceSnippet( + testName: string, + language: LANGUAGE, + provider: InferenceProvider +): Promise { const fixtureFolder = getFixtureFolder(testName); const files = await fs.readdir(fixtureFolder); const expectedSnippets: InferenceSnippet[] = []; - for (const file of files.filter((file) => file.endsWith("." + language)).sort()) { - const client = path.basename(file).split(".").slice(1, -1).join("."); // e.g. '0.huggingface.js.js' => "huggingface.js" + for (const file of files.filter((file) => file.endsWith("." + language) && file.includes(`.${provider}.`)).sort()) { + const client = path.basename(file).split(".").slice(1, -2).join("."); // e.g. '0.huggingface.js.replicate.js' => "huggingface.js" const content = await fs.readFile(path.join(fixtureFolder, file), { encoding: "utf-8" }); - expectedSnippets.push(client === "default" ? { content } : { client, content }); + expectedSnippets.push({ client, content }); } return expectedSnippets; } -async function saveExpectedInferenceSnippet(testName: string, language: LANGUAGE, snippets: InferenceSnippet[]) { +async function saveExpectedInferenceSnippet( + testName: string, + language: LANGUAGE, + provider: InferenceProvider, + snippets: InferenceSnippet[] +) { const fixtureFolder = getFixtureFolder(testName); await fs.mkdir(fixtureFolder, { recursive: true }); for (const [index, snippet] of snippets.entries()) { - const file = path.join(fixtureFolder, `${index}.${snippet.client ?? "default"}.${language}`); + const file = path.join(fixtureFolder, `${index}.${snippet.client ?? 
"default"}.${provider}.${language}`); await fs.writeFile(file, snippet.content); } } @@ -147,13 +163,15 @@ if (import.meta.vitest) { const { describe, expect, it } = import.meta.vitest; describe("inference API snippets", () => { - TEST_CASES.forEach(({ testName, model, languages, opts }) => { + TEST_CASES.forEach(({ testName, model, languages, providers, opts }) => { describe(testName, () => { languages.forEach((language) => { - it(language, async () => { - const generatedSnippets = generateInferenceSnippet(model, language, opts); - const expectedSnippets = await getExpectedInferenceSnippet(testName, language); - expect(generatedSnippets).toEqual(expectedSnippets); + providers.forEach((provider) => { + it(language, async () => { + const generatedSnippets = generateInferenceSnippet(model, language, provider, opts); + const expectedSnippets = await getExpectedInferenceSnippet(testName, language, provider); + expect(generatedSnippets).toEqual(expectedSnippets); + }); }); }); }); @@ -166,11 +184,13 @@ if (import.meta.vitest) { await fs.rm(path.join(rootDirFinder(), "snippets-fixtures"), { recursive: true, force: true }); console.debug(" 🏭 Generating new fixtures..."); - TEST_CASES.forEach(({ testName, model, languages, opts }) => { - console.debug(` ${testName} (${languages.join(", ")})`); + TEST_CASES.forEach(({ testName, model, languages, providers, opts }) => { + console.debug(` ${testName} (${languages.join(", ")}) (${providers.join(", ")})`); languages.forEach(async (language) => { - const generatedSnippets = generateInferenceSnippet(model, language, opts); - await saveExpectedInferenceSnippet(testName, language, generatedSnippets); + providers.forEach(async (provider) => { + const generatedSnippets = generateInferenceSnippet(model, language, provider, opts); + await saveExpectedInferenceSnippet(testName, language, provider, generatedSnippets); + }); }); }); console.log("✅ All done!"); diff --git a/packages/tasks-gen/scripts/inference-codegen.ts b/packages/tasks-gen/scripts/inference-codegen.ts index e9fb05b804..349b2870a2 100644 --- a/packages/tasks-gen/scripts/inference-codegen.ts +++ b/packages/tasks-gen/scripts/inference-codegen.ts @@ -147,10 +147,10 @@ async function generateBinaryInputTypes( const propName = propSignature.name.getText(tsSource); const propIsMedia = - typeof spec["properties"] !== "string" && - typeof spec["properties"]?.[propName] !== "string" && - typeof spec["properties"]?.[propName]?.["comment"] === "string" - ? !!spec["properties"]?.[propName]?.["comment"]?.includes("type=binary") + // eslint-disable-next-line @typescript-eslint/no-explicit-any + typeof (spec as any)["properties"]?.[propName]?.["comment"] === "string" + ? 
// eslint-disable-next-line @typescript-eslint/no-explicit-any + !!(spec as any)["properties"][propName]["comment"].includes("type=binary") : false; if (!propIsMedia) { return; diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.default.sh b/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.curl.hf-inference.sh similarity index 100% rename from packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.default.sh rename to packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.curl.hf-inference.sh diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.curl.together.sh b/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.curl.together.sh new file mode 100644 index 0000000000..4fc849c059 --- /dev/null +++ b/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.curl.together.sh @@ -0,0 +1,14 @@ +curl 'https://huggingface.co/api/inference-proxy/together/v1/chat/completions' \ +-H 'Authorization: Bearer api_token' \ +-H 'Content-Type: application/json' \ +--data '{ + "model": "meta-llama/Llama-3.1-8B-Instruct", + "messages": [ + { + "role": "user", + "content": "What is the capital of France?" + } + ], + "max_tokens": 500, + "stream": false +}' \ No newline at end of file diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface.js.hf-inference.js b/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface.js.hf-inference.js new file mode 100644 index 0000000000..0bb99c2c8c --- /dev/null +++ b/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface.js.hf-inference.js @@ -0,0 +1,17 @@ +import { HfInference } from "@huggingface/inference"; + +const client = new HfInference("api_token"); + +const chatCompletion = await client.chatCompletion({ + model: "meta-llama/Llama-3.1-8B-Instruct", + messages: [ + { + role: "user", + content: "What is the capital of France?" + } + ], + provider: "hf-inference", + max_tokens: 500 +}); + +console.log(chatCompletion.choices[0].message); \ No newline at end of file diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface.js.js b/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface.js.together.js similarity index 93% rename from packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface.js.js rename to packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface.js.together.js index c9243e1792..fb00ca5d8d 100644 --- a/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface.js.js +++ b/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface.js.together.js @@ -10,6 +10,7 @@ const chatCompletion = await client.chatCompletion({ content: "What is the capital of France?" 
} ], + provider: "together", max_tokens: 500 }); diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface_hub.hf-inference.py b/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface_hub.hf-inference.py new file mode 100644 index 0000000000..a4e9d17d69 --- /dev/null +++ b/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface_hub.hf-inference.py @@ -0,0 +1,21 @@ +from huggingface_hub import InferenceClient + +client = InferenceClient( + provider="hf-inference", + api_key="api_token" +) + +messages = [ + { + "role": "user", + "content": "What is the capital of France?" + } +] + +completion = client.chat.completions.create( + model="meta-llama/Llama-3.1-8B-Instruct", + messages=messages, + max_tokens=500 +) + +print(completion.choices[0].message) \ No newline at end of file diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface_hub.py b/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface_hub.together.py similarity index 80% rename from packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface_hub.py rename to packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface_hub.together.py index e60e63114c..7464a8b221 100644 --- a/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface_hub.py +++ b/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/0.huggingface_hub.together.py @@ -1,6 +1,9 @@ from huggingface_hub import InferenceClient -client = InferenceClient(api_key="api_token") +client = InferenceClient( + provider="together", + api_key="api_token" +) messages = [ { diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.js b/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.hf-inference.js similarity index 80% rename from packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.js rename to packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.hf-inference.js index ddccf8d502..63721ce2d5 100644 --- a/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.js +++ b/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.hf-inference.js @@ -1,8 +1,8 @@ import { OpenAI } from "openai"; const client = new OpenAI({ - baseURL: "https://api-inference.huggingface.co/v1/", - apiKey: "api_token" + baseURL: "https://api-inference.huggingface.co/v1/", + apiKey: "api_token" }); const chatCompletion = await client.chat.completions.create({ diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.py b/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.hf-inference.py similarity index 100% rename from packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.py rename to packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.hf-inference.py diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.together.js b/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.together.js new file mode 100644 index 0000000000..9dc9b1f2bf --- /dev/null +++ b/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.together.js @@ -0,0 +1,19 @@ +import { OpenAI } from "openai"; + +const client = new OpenAI({ + baseURL: "https://huggingface.co/api/inference-proxy/together", + apiKey: 
"api_token" +}); + +const chatCompletion = await client.chat.completions.create({ + model: "meta-llama/Llama-3.1-8B-Instruct", + messages: [ + { + role: "user", + content: "What is the capital of France?" + } + ], + max_tokens: 500 +}); + +console.log(chatCompletion.choices[0].message); \ No newline at end of file diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.together.py b/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.together.py new file mode 100644 index 0000000000..ee179ef368 --- /dev/null +++ b/packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.openai.together.py @@ -0,0 +1,21 @@ +from openai import OpenAI + +client = OpenAI( + base_url="https://huggingface.co/api/inference-proxy/together", + api_key="api_token" +) + +messages = [ + { + "role": "user", + "content": "What is the capital of France?" + } +] + +completion = client.chat.completions.create( + model="meta-llama/Llama-3.1-8B-Instruct", + messages=messages, + max_tokens=500 +) + +print(completion.choices[0].message) \ No newline at end of file diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-stream/0.default.sh b/packages/tasks-gen/snippets-fixtures/conversational-llm-stream/0.curl.hf-inference.sh similarity index 100% rename from packages/tasks-gen/snippets-fixtures/conversational-llm-stream/0.default.sh rename to packages/tasks-gen/snippets-fixtures/conversational-llm-stream/0.curl.hf-inference.sh diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-stream/0.huggingface.js.js b/packages/tasks-gen/snippets-fixtures/conversational-llm-stream/0.huggingface.js.hf-inference.js similarity index 94% rename from packages/tasks-gen/snippets-fixtures/conversational-llm-stream/0.huggingface.js.js rename to packages/tasks-gen/snippets-fixtures/conversational-llm-stream/0.huggingface.js.hf-inference.js index 581e0a3e8a..86a6456704 100644 --- a/packages/tasks-gen/snippets-fixtures/conversational-llm-stream/0.huggingface.js.js +++ b/packages/tasks-gen/snippets-fixtures/conversational-llm-stream/0.huggingface.js.hf-inference.js @@ -12,6 +12,7 @@ const stream = client.chatCompletionStream({ content: "What is the capital of France?" 
} ], + provider: "hf-inference", max_tokens: 500 }); diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-stream/0.huggingface_hub.py b/packages/tasks-gen/snippets-fixtures/conversational-llm-stream/0.huggingface_hub.hf-inference.py similarity index 82% rename from packages/tasks-gen/snippets-fixtures/conversational-llm-stream/0.huggingface_hub.py rename to packages/tasks-gen/snippets-fixtures/conversational-llm-stream/0.huggingface_hub.hf-inference.py index 38a5efcd6e..a9f9567d0c 100644 --- a/packages/tasks-gen/snippets-fixtures/conversational-llm-stream/0.huggingface_hub.py +++ b/packages/tasks-gen/snippets-fixtures/conversational-llm-stream/0.huggingface_hub.hf-inference.py @@ -1,6 +1,9 @@ from huggingface_hub import InferenceClient -client = InferenceClient(api_key="api_token") +client = InferenceClient( + provider="hf-inference", + api_key="api_token" +) messages = [ { diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-stream/1.openai.js b/packages/tasks-gen/snippets-fixtures/conversational-llm-stream/1.openai.hf-inference.js similarity index 95% rename from packages/tasks-gen/snippets-fixtures/conversational-llm-stream/1.openai.js rename to packages/tasks-gen/snippets-fixtures/conversational-llm-stream/1.openai.hf-inference.js index ccb3cb15b1..6bc97ab95d 100644 --- a/packages/tasks-gen/snippets-fixtures/conversational-llm-stream/1.openai.js +++ b/packages/tasks-gen/snippets-fixtures/conversational-llm-stream/1.openai.hf-inference.js @@ -2,7 +2,7 @@ import { OpenAI } from "openai"; const client = new OpenAI({ baseURL: "https://api-inference.huggingface.co/v1/", - apiKey: "api_token" + apiKey: "api_token" }); let out = ""; diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-stream/1.openai.py b/packages/tasks-gen/snippets-fixtures/conversational-llm-stream/1.openai.hf-inference.py similarity index 100% rename from packages/tasks-gen/snippets-fixtures/conversational-llm-stream/1.openai.py rename to packages/tasks-gen/snippets-fixtures/conversational-llm-stream/1.openai.hf-inference.py diff --git a/packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/0.default.sh b/packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/0.curl.hf-inference.sh similarity index 100% rename from packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/0.default.sh rename to packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/0.curl.hf-inference.sh diff --git a/packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/0.huggingface.js.js b/packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/0.huggingface.js.hf-inference.js similarity index 95% rename from packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/0.huggingface.js.js rename to packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/0.huggingface.js.hf-inference.js index b7e54db67c..5c7c693350 100644 --- a/packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/0.huggingface.js.js +++ b/packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/0.huggingface.js.hf-inference.js @@ -21,6 +21,7 @@ const chatCompletion = await client.chatCompletion({ ] } ], + provider: "hf-inference", max_tokens: 500 }); diff --git a/packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/0.huggingface_hub.py b/packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/0.huggingface_hub.hf-inference.py similarity index 87% rename from 
packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/0.huggingface_hub.py rename to packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/0.huggingface_hub.hf-inference.py index 82c2389972..5dc7d45c94 100644 --- a/packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/0.huggingface_hub.py +++ b/packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/0.huggingface_hub.hf-inference.py @@ -1,6 +1,9 @@ from huggingface_hub import InferenceClient -client = InferenceClient(api_key="api_token") +client = InferenceClient( + provider="hf-inference", + api_key="api_token" +) messages = [ { diff --git a/packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/1.openai.js b/packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/1.openai.hf-inference.js similarity index 87% rename from packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/1.openai.js rename to packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/1.openai.hf-inference.js index 6badefd525..cbfddb53a6 100644 --- a/packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/1.openai.js +++ b/packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/1.openai.hf-inference.js @@ -1,8 +1,8 @@ import { OpenAI } from "openai"; const client = new OpenAI({ - baseURL: "https://api-inference.huggingface.co/v1/", - apiKey: "api_token" + baseURL: "https://api-inference.huggingface.co/v1/", + apiKey: "api_token" }); const chatCompletion = await client.chat.completions.create({ diff --git a/packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/1.openai.py b/packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/1.openai.hf-inference.py similarity index 100% rename from packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/1.openai.py rename to packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/1.openai.hf-inference.py diff --git a/packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/0.default.sh b/packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/0.curl.hf-inference.sh similarity index 100% rename from packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/0.default.sh rename to packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/0.curl.hf-inference.sh diff --git a/packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/0.huggingface.js.js b/packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/0.huggingface.js.hf-inference.js similarity index 96% rename from packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/0.huggingface.js.js rename to packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/0.huggingface.js.hf-inference.js index e91f14d81c..bcab793db8 100644 --- a/packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/0.huggingface.js.js +++ b/packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/0.huggingface.js.hf-inference.js @@ -23,6 +23,7 @@ const stream = client.chatCompletionStream({ ] } ], + provider: "hf-inference", max_tokens: 500 }); diff --git a/packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/0.huggingface_hub.py b/packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/0.huggingface_hub.hf-inference.py similarity index 88% rename from packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/0.huggingface_hub.py rename to packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/0.huggingface_hub.hf-inference.py index 9eaf7a1677..d154c80902 100644 
--- a/packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/0.huggingface_hub.py +++ b/packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/0.huggingface_hub.hf-inference.py @@ -1,6 +1,9 @@ from huggingface_hub import InferenceClient -client = InferenceClient(api_key="api_token") +client = InferenceClient( + provider="hf-inference", + api_key="api_token" +) messages = [ { diff --git a/packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/1.openai.js b/packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/1.openai.hf-inference.js similarity index 97% rename from packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/1.openai.js rename to packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/1.openai.hf-inference.js index 59447faa05..c947342b3e 100644 --- a/packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/1.openai.js +++ b/packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/1.openai.hf-inference.js @@ -2,7 +2,7 @@ import { OpenAI } from "openai"; const client = new OpenAI({ baseURL: "https://api-inference.huggingface.co/v1/", - apiKey: "api_token" + apiKey: "api_token" }); let out = ""; diff --git a/packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/1.openai.py b/packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/1.openai.hf-inference.py similarity index 100% rename from packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/1.openai.py rename to packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/1.openai.hf-inference.py diff --git a/packages/tasks-gen/snippets-fixtures/text-to-image/0.default.sh b/packages/tasks-gen/snippets-fixtures/text-to-image/0.curl.hf-inference.sh similarity index 100% rename from packages/tasks-gen/snippets-fixtures/text-to-image/0.default.sh rename to packages/tasks-gen/snippets-fixtures/text-to-image/0.curl.hf-inference.sh diff --git a/packages/tasks-gen/snippets-fixtures/text-to-image/0.huggingface.js.hf-inference.js b/packages/tasks-gen/snippets-fixtures/text-to-image/0.huggingface.js.hf-inference.js new file mode 100644 index 0000000000..f506de747a --- /dev/null +++ b/packages/tasks-gen/snippets-fixtures/text-to-image/0.huggingface.js.hf-inference.js @@ -0,0 +1,11 @@ +import { HfInference } from "@huggingface/inference"; + +const client = new HfInference("api_token"); + +const image = await client.textToImage({ + model: "black-forest-labs/FLUX.1-schnell", + inputs: "Astronaut riding a horse", + parameters: { num_inference_steps: 5 }, + provider: "hf-inference", +}); +/// Use the generated image (it's a Blob) diff --git a/packages/tasks-gen/snippets-fixtures/text-to-image/0.huggingface_hub.hf-inference.py b/packages/tasks-gen/snippets-fixtures/text-to-image/0.huggingface_hub.hf-inference.py new file mode 100644 index 0000000000..2a488324f8 --- /dev/null +++ b/packages/tasks-gen/snippets-fixtures/text-to-image/0.huggingface_hub.hf-inference.py @@ -0,0 +1,12 @@ +from huggingface_hub import InferenceClient + +client = InferenceClient( + provider="hf-inference", + api_key="api_token" +) + +# output is a PIL.Image object +image = client.text_to_image( + "Astronaut riding a horse", + model="black-forest-labs/FLUX.1-schnell" +) \ No newline at end of file diff --git a/packages/tasks-gen/snippets-fixtures/text-to-image/0.huggingface_hub.py b/packages/tasks-gen/snippets-fixtures/text-to-image/0.huggingface_hub.py deleted file mode 100644 index a0914bc5e8..0000000000 --- a/packages/tasks-gen/snippets-fixtures/text-to-image/0.huggingface_hub.py +++ 
/dev/null @@ -1,5 +0,0 @@ -from huggingface_hub import InferenceClient -client = InferenceClient("black-forest-labs/FLUX.1-schnell", token="api_token") - -# output is a PIL.Image object -image = client.text_to_image("Astronaut riding a horse") \ No newline at end of file diff --git a/packages/tasks-gen/snippets-fixtures/text-to-image/0.default.js b/packages/tasks-gen/snippets-fixtures/text-to-image/1.fetch.hf-inference.js similarity index 100% rename from packages/tasks-gen/snippets-fixtures/text-to-image/0.default.js rename to packages/tasks-gen/snippets-fixtures/text-to-image/1.fetch.hf-inference.js diff --git a/packages/tasks-gen/snippets-fixtures/text-to-image/1.requests.py b/packages/tasks-gen/snippets-fixtures/text-to-image/1.requests.hf-inference.py similarity index 99% rename from packages/tasks-gen/snippets-fixtures/text-to-image/1.requests.py rename to packages/tasks-gen/snippets-fixtures/text-to-image/1.requests.hf-inference.py index e71ec19361..40367b2123 100644 --- a/packages/tasks-gen/snippets-fixtures/text-to-image/1.requests.py +++ b/packages/tasks-gen/snippets-fixtures/text-to-image/1.requests.hf-inference.py @@ -6,6 +6,7 @@ def query(payload): response = requests.post(API_URL, headers=headers, json=payload) return response.content + image_bytes = query({ "inputs": "Astronaut riding a horse", }) diff --git a/packages/tasks/src/index.ts b/packages/tasks/src/index.ts index c350fecd2e..921acfa941 100644 --- a/packages/tasks/src/index.ts +++ b/packages/tasks/src/index.ts @@ -58,3 +58,5 @@ export type { LocalApp, LocalAppKey, LocalAppSnippet } from "./local-apps.js"; export { DATASET_LIBRARIES_UI_ELEMENTS } from "./dataset-libraries.js"; export type { DatasetLibraryUiElement, DatasetLibraryKey } from "./dataset-libraries.js"; + +export * from "./inference-providers.js"; diff --git a/packages/tasks/src/inference-providers.ts b/packages/tasks/src/inference-providers.ts new file mode 100644 index 0000000000..c1deffbbb8 --- /dev/null +++ b/packages/tasks/src/inference-providers.ts @@ -0,0 +1,16 @@ +export const INFERENCE_PROVIDERS = ["hf-inference", "fal-ai", "replicate", "sambanova", "together"] as const; + +export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number]; + +export const HF_HUB_INFERENCE_PROXY_TEMPLATE = `https://huggingface.co/api/inference-proxy/{{PROVIDER}}`; + +/** + * URL to set as baseUrl in the OpenAI SDK. + * + * TODO(Expose this from HfInference in the future?) + */ +export function openAIbaseUrl(provider: InferenceProvider): string { + return provider === "hf-inference" + ? 
"https://api-inference.huggingface.co/v1/" + : HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider); +} diff --git a/packages/tasks/src/snippets/curl.ts b/packages/tasks/src/snippets/curl.ts index f3ba735f3d..4fad126799 100644 --- a/packages/tasks/src/snippets/curl.ts +++ b/packages/tasks/src/snippets/curl.ts @@ -1,20 +1,35 @@ +import { HF_HUB_INFERENCE_PROXY_TEMPLATE, type InferenceProvider } from "../inference-providers.js"; import type { PipelineType } from "../pipelines.js"; import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js"; import { stringifyGenerationConfig, stringifyMessages } from "./common.js"; import { getModelInputSnippet } from "./inputs.js"; import type { InferenceSnippet, ModelDataMinimal } from "./types.js"; -export const snippetBasic = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({ - content: `curl https://api-inference.huggingface.co/models/${model.id} \\ +export const snippetBasic = ( + model: ModelDataMinimal, + accessToken: string, + provider: InferenceProvider +): InferenceSnippet[] => { + if (provider !== "hf-inference") { + return []; + } + return [ + { + client: "curl", + content: `\ +curl https://api-inference.huggingface.co/models/${model.id} \\ -X POST \\ -d '{"inputs": ${getModelInputSnippet(model, true)}}' \\ -H 'Content-Type: application/json' \\ -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`, -}); + }, + ]; +}; export const snippetTextGeneration = ( model: ModelDataMinimal, accessToken: string, + provider: InferenceProvider, opts?: { streaming?: boolean; messages?: ChatCompletionInputMessage[]; @@ -22,8 +37,13 @@ export const snippetTextGeneration = ( max_tokens?: GenerationParameters["max_tokens"]; top_p?: GenerationParameters["top_p"]; } -): InferenceSnippet => { +): InferenceSnippet[] => { if (model.tags.includes("conversational")) { + const baseUrl = + provider === "hf-inference" + ? `https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions` + : HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider) + "/v1/chat/completions"; + // Conversational model detected, so we display a code snippet that features the Messages API const streaming = opts?.streaming ?? true; const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[]; @@ -34,8 +54,10 @@ export const snippetTextGeneration = ( max_tokens: opts?.max_tokens ?? 500, ...(opts?.top_p ? 
{ top_p: opts.top_p } : undefined), }; - return { - content: `curl 'https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions' \\ + return [ + { + client: "curl", + content: `curl '${baseUrl}' \\ -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}' \\ -H 'Content-Type: application/json' \\ --data '{ @@ -52,34 +74,64 @@ export const snippetTextGeneration = ( })}, "stream": ${!!streaming} }'`, - }; + }, + ]; } else { - return snippetBasic(model, accessToken); + return snippetBasic(model, accessToken, provider); } }; -export const snippetZeroShotClassification = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({ - content: `curl https://api-inference.huggingface.co/models/${model.id} \\ +export const snippetZeroShotClassification = ( + model: ModelDataMinimal, + accessToken: string, + provider: InferenceProvider +): InferenceSnippet[] => { + if (provider !== "hf-inference") { + return []; + } + return [ + { + client: "curl", + content: `curl https://api-inference.huggingface.co/models/${model.id} \\ -X POST \\ -d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\ -H 'Content-Type: application/json' \\ -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`, -}); + }, + ]; +}; -export const snippetFile = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({ - content: `curl https://api-inference.huggingface.co/models/${model.id} \\ +export const snippetFile = ( + model: ModelDataMinimal, + accessToken: string, + provider: InferenceProvider +): InferenceSnippet[] => { + if (provider !== "hf-inference") { + return []; + } + return [ + { + client: "curl", + content: `curl https://api-inference.huggingface.co/models/${model.id} \\ -X POST \\ --data-binary '@${getModelInputSnippet(model, true, true)}' \\ -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`, -}); + }, + ]; +}; export const curlSnippets: Partial< Record< PipelineType, - (model: ModelDataMinimal, accessToken: string, opts?: Record) => InferenceSnippet + ( + model: ModelDataMinimal, + accessToken: string, + provider: InferenceProvider, + opts?: Record + ) => InferenceSnippet[] > > = { - // Same order as in js/src/lib/interfaces/Types.ts + // Same order as in tasks/src/pipelines.ts "text-classification": snippetBasic, "token-classification": snippetBasic, "table-question-answering": snippetBasic, @@ -108,13 +160,10 @@ export const curlSnippets: Partial< export function getCurlInferenceSnippet( model: ModelDataMinimal, accessToken: string, + provider: InferenceProvider, opts?: Record -): InferenceSnippet { +): InferenceSnippet[] { return model.pipeline_tag && model.pipeline_tag in curlSnippets - ? curlSnippets[model.pipeline_tag]?.(model, accessToken, opts) ?? { content: "" } - : { content: "" }; -} - -export function hasCurlInferenceSnippet(model: Pick): boolean { - return !!model.pipeline_tag && model.pipeline_tag in curlSnippets; + ? curlSnippets[model.pipeline_tag]?.(model, accessToken, provider, opts) ?? 
[] + : []; } diff --git a/packages/tasks/src/snippets/js.ts b/packages/tasks/src/snippets/js.ts index 9707525822..a1389b6f92 100644 --- a/packages/tasks/src/snippets/js.ts +++ b/packages/tasks/src/snippets/js.ts @@ -1,11 +1,54 @@ +import { openAIbaseUrl, type InferenceProvider } from "../inference-providers.js"; import type { PipelineType } from "../pipelines.js"; import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js"; import { stringifyGenerationConfig, stringifyMessages } from "./common.js"; import { getModelInputSnippet } from "./inputs.js"; import type { InferenceSnippet, ModelDataMinimal } from "./types.js"; -export const snippetBasic = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({ - content: `async function query(data) { +const HFJS_METHODS: Record = { + "text-classification": "textClassification", + "token-classification": "tokenClassification", + "table-question-answering": "tableQuestionAnswering", + "question-answering": "questionAnswering", + translation: "translation", + summarization: "summarization", + "feature-extraction": "featureExtraction", + "text-generation": "textGeneration", + "text2text-generation": "textGeneration", + "fill-mask": "fillMask", + "sentence-similarity": "sentenceSimilarity", +}; + +export const snippetBasic = ( + model: ModelDataMinimal, + accessToken: string, + provider: InferenceProvider +): InferenceSnippet[] => { + return [ + ...(model.pipeline_tag && model.pipeline_tag in HFJS_METHODS + ? [ + { + client: "huggingface.js", + content: `\ +import { HfInference } from "@huggingface/inference"; + +const client = new HfInference("${accessToken || `{API_TOKEN}`}"); + +const output = await client.${HFJS_METHODS[model.pipeline_tag]}({ + model: "${model.id}", + inputs: ${getModelInputSnippet(model)}, + provider: "${provider}", +}); + +console.log(output) +`, + }, + ] + : []), + { + client: "fetch", + content: `\ +async function query(data) { const response = await fetch( "https://api-inference.huggingface.co/models/${model.id}", { @@ -24,11 +67,14 @@ export const snippetBasic = (model: ModelDataMinimal, accessToken: string): Infe query({"inputs": ${getModelInputSnippet(model)}}).then((response) => { console.log(JSON.stringify(response)); });`, -}); + }, + ]; +}; export const snippetTextGeneration = ( model: ModelDataMinimal, accessToken: string, + provider: InferenceProvider, opts?: { streaming?: boolean; messages?: ChatCompletionInputMessage[]; @@ -36,7 +82,7 @@ export const snippetTextGeneration = ( max_tokens?: GenerationParameters["max_tokens"]; top_p?: GenerationParameters["top_p"]; } -): InferenceSnippet | InferenceSnippet[] => { +): InferenceSnippet[] => { if (model.tags.includes("conversational")) { // Conversational model detected, so we display a code snippet that features the Messages API const streaming = opts?.streaming ?? 
true; @@ -67,6 +113,7 @@ let out = ""; const stream = client.chatCompletionStream({ model: "${model.id}", messages: ${messagesStr}, + provider: "${provider}", ${configStr} }); @@ -83,8 +130,8 @@ for await (const chunk of stream) { content: `import { OpenAI } from "openai"; const client = new OpenAI({ - baseURL: "https://api-inference.huggingface.co/v1/", - apiKey: "${accessToken || `{API_TOKEN}`}" + baseURL: "${openAIbaseUrl(provider)}", + apiKey: "${accessToken || `{API_TOKEN}`}" }); let out = ""; @@ -116,6 +163,7 @@ const client = new HfInference("${accessToken || `{API_TOKEN}`}"); const chatCompletion = await client.chatCompletion({ model: "${model.id}", messages: ${messagesStr}, + provider: "${provider}", ${configStr} }); @@ -126,8 +174,8 @@ console.log(chatCompletion.choices[0].message);`, content: `import { OpenAI } from "openai"; const client = new OpenAI({ - baseURL: "https://api-inference.huggingface.co/v1/", - apiKey: "${accessToken || `{API_TOKEN}`}" + baseURL: "${openAIbaseUrl(provider)}", + apiKey: "${accessToken || `{API_TOKEN}`}" }); const chatCompletion = await client.chat.completions.create({ @@ -141,36 +189,66 @@ console.log(chatCompletion.choices[0].message);`, ]; } } else { - return snippetBasic(model, accessToken); + return snippetBasic(model, accessToken, provider); } }; -export const snippetZeroShotClassification = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({ - content: `async function query(data) { - const response = await fetch( - "https://api-inference.huggingface.co/models/${model.id}", +export const snippetZeroShotClassification = (model: ModelDataMinimal, accessToken: string): InferenceSnippet[] => { + return [ { - headers: { - Authorization: "Bearer ${accessToken || `{API_TOKEN}`}", - "Content-Type": "application/json", - }, - method: "POST", - body: JSON.stringify(data), + client: "fetch", + content: `async function query(data) { + const response = await fetch( + "https://api-inference.huggingface.co/models/${model.id}", + { + headers: { + Authorization: "Bearer ${accessToken || `{API_TOKEN}`}", + "Content-Type": "application/json", + }, + method: "POST", + body: JSON.stringify(data), + } + ); + const result = await response.json(); + return result; } - ); - const result = await response.json(); - return result; -} + + query({"inputs": ${getModelInputSnippet( + model + )}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}).then((response) => { + console.log(JSON.stringify(response)); + });`, + }, + ]; +}; -query({"inputs": ${getModelInputSnippet( - model - )}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}).then((response) => { - console.log(JSON.stringify(response)); -});`, -}); +export const snippetTextToImage = ( + model: ModelDataMinimal, + accessToken: string, + provider: InferenceProvider +): InferenceSnippet[] => { + return [ + { + client: "huggingface.js", + content: `\ +import { HfInference } from "@huggingface/inference"; + +const client = new HfInference("${accessToken || `{API_TOKEN}`}"); -export const snippetTextToImage = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({ - content: `async function query(data) { +const image = await client.textToImage({ + model: "${model.id}", + inputs: ${getModelInputSnippet(model)}, + parameters: { num_inference_steps: 5 }, + provider: "${provider}", +}); +/// Use the generated image (it's a Blob) +`, + }, + ...(provider === "hf-inference" + ? 
[ + { + client: "fetch", + content: `async function query(data) { const response = await fetch( "https://api-inference.huggingface.co/models/${model.id}", { @@ -188,9 +266,20 @@ export const snippetTextToImage = (model: ModelDataMinimal, accessToken: string) query({"inputs": ${getModelInputSnippet(model)}}).then((response) => { // Use image });`, -}); + }, + ] + : []), + ]; +}; -export const snippetTextToAudio = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => { +export const snippetTextToAudio = ( + model: ModelDataMinimal, + accessToken: string, + provider: InferenceProvider +): InferenceSnippet[] => { + if (provider !== "hf-inference") { + return []; + } const commonSnippet = `async function query(data) { const response = await fetch( "https://api-inference.huggingface.co/models/${model.id}", @@ -204,22 +293,27 @@ export const snippetTextToAudio = (model: ModelDataMinimal, accessToken: string) } );`; if (model.library_name === "transformers") { - return { - content: - commonSnippet + - ` + return [ + { + client: "fetch", + content: + commonSnippet + + ` const result = await response.blob(); return result; } query({"inputs": ${getModelInputSnippet(model)}}).then((response) => { // Returns a byte object of the Audio wavform. Use it directly! });`, - }; + }, + ]; } else { - return { - content: - commonSnippet + - ` + return [ + { + client: "fetch", + content: + commonSnippet + + ` const result = await response.json(); return result; } @@ -227,12 +321,51 @@ export const snippetTextToAudio = (model: ModelDataMinimal, accessToken: string) query({"inputs": ${getModelInputSnippet(model)}}).then((response) => { console.log(JSON.stringify(response)); });`, - }; + }, + ]; } }; -export const snippetFile = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({ - content: `async function query(filename) { +export const snippetAutomaticSpeechRecognition = ( + model: ModelDataMinimal, + accessToken: string, + provider: InferenceProvider +): InferenceSnippet[] => { + return [ + { + client: "huggingface.js", + content: `\ +import { HfInference } from "@huggingface/inference"; + +const client = new HfInference("${accessToken || `{API_TOKEN}`}"); + +const data = fs.readFileSync(${getModelInputSnippet(model)}); + +const output = await client.automaticSpeechRecognition({ + data, + model: "${model.id}", + provider: "${provider}", +}); + +console.log(output); +`, + }, + ...(provider === "hf-inference" ? 
snippetFile(model, accessToken, provider) : []), + ]; +}; + +export const snippetFile = ( + model: ModelDataMinimal, + accessToken: string, + provider: InferenceProvider +): InferenceSnippet[] => { + if (provider !== "hf-inference") { + return []; + } + return [ + { + client: "fetch", + content: `async function query(filename) { const data = fs.readFileSync(filename); const response = await fetch( "https://api-inference.huggingface.co/models/${model.id}", @@ -252,7 +385,9 @@ export const snippetFile = (model: ModelDataMinimal, accessToken: string): Infer query(${getModelInputSnippet(model)}).then((response) => { console.log(JSON.stringify(response)); });`, -}); + }, + ]; +}; export const jsSnippets: Partial< Record< @@ -260,11 +395,12 @@ export const jsSnippets: Partial< ( model: ModelDataMinimal, accessToken: string, + provider: InferenceProvider, opts?: Record - ) => InferenceSnippet | InferenceSnippet[] + ) => InferenceSnippet[] > > = { - // Same order as in js/src/lib/interfaces/Types.ts + // Same order as in tasks/src/pipelines.ts "text-classification": snippetBasic, "token-classification": snippetBasic, "table-question-answering": snippetBasic, @@ -278,7 +414,7 @@ export const jsSnippets: Partial< "text2text-generation": snippetBasic, "fill-mask": snippetBasic, "sentence-similarity": snippetBasic, - "automatic-speech-recognition": snippetFile, + "automatic-speech-recognition": snippetAutomaticSpeechRecognition, "text-to-image": snippetTextToImage, "text-to-speech": snippetTextToAudio, "text-to-audio": snippetTextToAudio, @@ -293,13 +429,10 @@ export const jsSnippets: Partial< export function getJsInferenceSnippet( model: ModelDataMinimal, accessToken: string, + provider: InferenceProvider, opts?: Record -): InferenceSnippet | InferenceSnippet[] { +): InferenceSnippet[] { return model.pipeline_tag && model.pipeline_tag in jsSnippets - ? jsSnippets[model.pipeline_tag]?.(model, accessToken, opts) ?? { content: "" } - : { content: "" }; -} - -export function hasJsInferenceSnippet(model: ModelDataMinimal): boolean { - return !!model.pipeline_tag && model.pipeline_tag in jsSnippets; + ? jsSnippets[model.pipeline_tag]?.(model, accessToken, provider, opts) ?? 
[] + : []; } diff --git a/packages/tasks/src/snippets/python.ts b/packages/tasks/src/snippets/python.ts index bdb148e391..53f5da8e8a 100644 --- a/packages/tasks/src/snippets/python.ts +++ b/packages/tasks/src/snippets/python.ts @@ -1,17 +1,23 @@ +import { HF_HUB_INFERENCE_PROXY_TEMPLATE, openAIbaseUrl, type InferenceProvider } from "../inference-providers.js"; import type { PipelineType } from "../pipelines.js"; import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js"; import { stringifyGenerationConfig, stringifyMessages } from "./common.js"; import { getModelInputSnippet } from "./inputs.js"; import type { InferenceSnippet, ModelDataMinimal } from "./types.js"; -const snippetImportInferenceClient = (model: ModelDataMinimal, accessToken: string): string => - `from huggingface_hub import InferenceClient -client = InferenceClient("${model.id}", token="${accessToken || "{API_TOKEN}"}") -`; +const snippetImportInferenceClient = (accessToken: string, provider: InferenceProvider): string => + `\ +from huggingface_hub import InferenceClient + +client = InferenceClient( + provider="${provider}", + api_key="${accessToken || "{API_TOKEN}"}" +)`; export const snippetConversational = ( model: ModelDataMinimal, accessToken: string, + provider: InferenceProvider, opts?: { streaming?: boolean; messages?: ChatCompletionInputMessage[]; @@ -39,9 +45,8 @@ export const snippetConversational = ( return [ { client: "huggingface_hub", - content: `from huggingface_hub import InferenceClient - -client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}") + content: `\ +${snippetImportInferenceClient(accessToken, provider)} messages = ${messagesStr} @@ -60,7 +65,7 @@ for chunk in stream: content: `from openai import OpenAI client = OpenAI( - base_url="https://api-inference.huggingface.co/v1/", + base_url="${openAIbaseUrl(provider)}", api_key="${accessToken || "{API_TOKEN}"}" ) @@ -81,9 +86,8 @@ for chunk in stream: return [ { client: "huggingface_hub", - content: `from huggingface_hub import InferenceClient - -client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}") + content: `\ +${snippetImportInferenceClient(accessToken, provider)} messages = ${messagesStr} @@ -100,7 +104,7 @@ print(completion.choices[0].message)`, content: `from openai import OpenAI client = OpenAI( - base_url="https://api-inference.huggingface.co/v1/", + base_url="${openAIbaseUrl(provider)}", api_key="${accessToken || "{API_TOKEN}"}" ) @@ -118,8 +122,11 @@ print(completion.choices[0].message)`, } }; -export const snippetZeroShotClassification = (model: ModelDataMinimal): InferenceSnippet => ({ - content: `def query(payload): +export const snippetZeroShotClassification = (model: ModelDataMinimal): InferenceSnippet[] => { + return [ + { + client: "requests", + content: `def query(payload): response = requests.post(API_URL, headers=headers, json=payload) return response.json() @@ -127,10 +134,15 @@ output = query({ "inputs": ${getModelInputSnippet(model)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}, })`, -}); + }, + ]; +}; -export const snippetZeroShotImageClassification = (model: ModelDataMinimal): InferenceSnippet => ({ - content: `def query(data): +export const snippetZeroShotImageClassification = (model: ModelDataMinimal): InferenceSnippet[] => { + return [ + { + client: "requests", + content: `def query(data): with open(data["image_path"], "rb") as f: img = f.read() payload={ @@ -144,40 +156,85 @@ output = query({ "image_path": ${getModelInputSnippet(model)}, 
"parameters": {"candidate_labels": ["cat", "dog", "llama"]}, })`, -}); + }, + ]; +}; -export const snippetBasic = (model: ModelDataMinimal): InferenceSnippet => ({ - content: `def query(payload): +export const snippetBasic = (model: ModelDataMinimal): InferenceSnippet[] => { + return [ + { + client: "requests", + content: `def query(payload): response = requests.post(API_URL, headers=headers, json=payload) return response.json() output = query({ "inputs": ${getModelInputSnippet(model)}, })`, -}); - -export const snippetFile = (model: ModelDataMinimal): InferenceSnippet => ({ - content: `def query(filename): - with open(filename, "rb") as f: - data = f.read() - response = requests.post(API_URL, headers=headers, data=data) - return response.json() - -output = query(${getModelInputSnippet(model)})`, -}); - -export const snippetTextToImage = (model: ModelDataMinimal, accessToken: string): InferenceSnippet[] => [ - { - client: "huggingface_hub", - content: `${snippetImportInferenceClient(model, accessToken)} + }, + ]; +}; + +export const snippetFile = (model: ModelDataMinimal): InferenceSnippet[] => { + return [ + { + client: "requests", + content: `def query(filename): + with open(filename, "rb") as f: + data = f.read() + response = requests.post(API_URL, headers=headers, data=data) + return response.json() + + output = query(${getModelInputSnippet(model)})`, + }, + ]; +}; + +export const snippetTextToImage = ( + model: ModelDataMinimal, + accessToken: string, + provider: InferenceProvider +): InferenceSnippet[] => { + return [ + { + client: "huggingface_hub", + content: `\ +${snippetImportInferenceClient(accessToken, provider)} + # output is a PIL.Image object -image = client.text_to_image(${getModelInputSnippet(model)})`, +image = client.text_to_image( + ${getModelInputSnippet(model)}, + model="${model.id}" +)`, + }, + ...(provider === "fal-ai" + ? [ + { + client: "fal-client", + content: `\ +import fal_client + +result = fal_client.subscribe( + # replace with correct id from fal.ai + "fal-ai/${model.id}", + arguments={ + "prompt": ${getModelInputSnippet(model)}, }, - { - client: "requests", - content: `def query(payload): +) +print(result) +`, + }, + ] + : []), + ...(provider === "hf-inference" + ? 
[ + { + client: "requests", + content: `\ +def query(payload): response = requests.post(API_URL, headers=headers, json=payload) return response.content + image_bytes = query({ "inputs": ${getModelInputSnippet(model)}, }) @@ -186,25 +243,35 @@ image_bytes = query({ import io from PIL import Image image = Image.open(io.BytesIO(image_bytes))`, - }, -]; + }, + ] + : []), + ]; +}; -export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet => ({ - content: `def query(payload): - response = requests.post(API_URL, headers=headers, json=payload) - return response.content -response = query({ - "inputs": {"data": ${getModelInputSnippet(model)}}, -})`, -}); +export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet[] => { + return [ + { + client: "requests", + content: `def query(payload): + response = requests.post(API_URL, headers=headers, json=payload) + return response.content + response = query({ + "inputs": {"data": ${getModelInputSnippet(model)}}, + })`, + }, + ]; +}; -export const snippetTextToAudio = (model: ModelDataMinimal): InferenceSnippet => { +export const snippetTextToAudio = (model: ModelDataMinimal): InferenceSnippet[] => { // Transformers TTS pipeline and api-inference-community (AIC) pipeline outputs are diverged // with the latest update to inference-api (IA). // Transformers IA returns a byte object (wav file), whereas AIC returns wav and sampling_rate. if (model.library_name === "transformers") { - return { - content: `def query(payload): + return [ + { + client: "requests", + content: `def query(payload): response = requests.post(API_URL, headers=headers, json=payload) return response.content @@ -214,10 +281,13 @@ audio_bytes = query({ # You can access the audio with IPython.display for example from IPython.display import Audio Audio(audio_bytes)`, - }; + }, + ]; } else { - return { - content: `def query(payload): + return [ + { + client: "requests", + content: `def query(payload): response = requests.post(API_URL, headers=headers, json=payload) return response.json() @@ -227,12 +297,16 @@ audio, sampling_rate = query({ # You can access the audio with IPython.display for example from IPython.display import Audio Audio(audio, rate=sampling_rate)`, - }; + }, + ]; } }; -export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): InferenceSnippet => ({ - content: `def query(payload): +export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): InferenceSnippet[] => { + return [ + { + client: "requests", + content: `def query(payload): with open(payload["image"], "rb") as f: img = f.read() payload["image"] = base64.b64encode(img).decode("utf-8") @@ -242,7 +316,9 @@ export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): Infer output = query({ "inputs": ${getModelInputSnippet(model)}, })`, -}); + }, + ]; +}; export const pythonSnippets: Partial< Record< @@ -250,8 +326,9 @@ export const pythonSnippets: Partial< ( model: ModelDataMinimal, accessToken: string, + provider: InferenceProvider, opts?: Record - ) => InferenceSnippet | InferenceSnippet[] + ) => InferenceSnippet[] > > = { // Same order as in tasks/src/pipelines.ts @@ -287,35 +364,37 @@ export const pythonSnippets: Partial< export function getPythonInferenceSnippet( model: ModelDataMinimal, accessToken: string, + provider: InferenceProvider, opts?: Record -): InferenceSnippet | InferenceSnippet[] { +): InferenceSnippet[] { if (model.tags.includes("conversational")) { // Conversational model detected, so we display a code snippet that features the 
Messages API - return snippetConversational(model, accessToken, opts); + return snippetConversational(model, accessToken, provider, opts); } else { - let snippets = + const snippets = model.pipeline_tag && model.pipeline_tag in pythonSnippets - ? pythonSnippets[model.pipeline_tag]?.(model, accessToken) ?? { content: "" } - : { content: "" }; + ? pythonSnippets[model.pipeline_tag]?.(model, accessToken, provider) ?? [] + : []; - snippets = Array.isArray(snippets) ? snippets : [snippets]; + const baseUrl = + provider === "hf-inference" + ? `https://api-inference.huggingface.co/models/${model.id}` + : HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider); return snippets.map((snippet) => { return { ...snippet, - content: snippet.content.includes("requests") - ? `import requests + content: + snippet.client === "requests" + ? `\ +import requests -API_URL = "https://api-inference.huggingface.co/models/${model.id}" +API_URL = "${baseUrl}" headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}} ${snippet.content}` - : snippet.content, + : snippet.content, }; }); } } - -export function hasPythonInferenceSnippet(model: ModelDataMinimal): boolean { - return !!model.pipeline_tag && model.pipeline_tag in pythonSnippets; -} diff --git a/packages/tasks/src/snippets/types.ts b/packages/tasks/src/snippets/types.ts index c6a78c278d..5cd807badb 100644 --- a/packages/tasks/src/snippets/types.ts +++ b/packages/tasks/src/snippets/types.ts @@ -12,5 +12,5 @@ export type ModelDataMinimal = Pick< export interface InferenceSnippet { content: string; - client?: string; // for instance: `client` could be `huggingface_hub` or `openai` client for Python snippets + client: string; // for instance: `client` could be `huggingface_hub` or `openai` client for Python snippets }
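
Note for reviewers: after this change every per-language generator (`getCurlInferenceSnippet`, `getJsInferenceSnippet`, `getPythonInferenceSnippet`) takes an `InferenceProvider` argument and always returns an `InferenceSnippet[]` with a mandatory `client` field. A minimal sketch of calling the JS generator, assuming the `snippets.js` namespace re-export already used by the fixture script; the model object is illustrative, copied from the conversational-llm fixture, not a real Hub lookup:

```ts
import { snippets } from "@huggingface/tasks";
import type { InferenceProvider, InferenceSnippet } from "@huggingface/tasks";

// Illustrative model description, mirroring the conversational-llm test case above.
const model: snippets.ModelDataMinimal = {
	id: "meta-llama/Llama-3.1-8B-Instruct",
	pipeline_tag: "text-generation",
	tags: ["conversational"],
	inference: "",
};

const provider: InferenceProvider = "together";

// Generators now return one snippet per client ("huggingface.js", "openai", ...).
const generated: InferenceSnippet[] = snippets.js.getJsInferenceSnippet(model, "api_token", provider, {
	streaming: false,
});

for (const { client, content } of generated) {
	console.log(`--- ${client} (${provider}) ---\n${content}`);
}
```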
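The routing rule encoded in `openAIbaseUrl` and `HF_HUB_INFERENCE_PROXY_TEMPLATE` is the core of the provider support: `hf-inference` keeps the historical serverless endpoint, while every other provider is routed through the Hub proxy. A small sketch of the mapping (the printed URLs match the new fixtures above):

```ts
import { INFERENCE_PROVIDERS, openAIbaseUrl } from "@huggingface/tasks";

// Prints the baseURL the generated OpenAI-client snippets use for each provider.
for (const provider of INFERENCE_PROVIDERS) {
	console.log(`${provider} -> ${openAIbaseUrl(provider)}`);
}
// hf-inference -> https://api-inference.huggingface.co/v1/
// fal-ai       -> https://huggingface.co/api/inference-proxy/fal-ai
// replicate    -> https://huggingface.co/api/inference-proxy/replicate
// sambanova    -> https://huggingface.co/api/inference-proxy/sambanova
// together     -> https://huggingface.co/api/inference-proxy/together
```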
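Fixture files move from `${index}.${client}.${language}` to `${index}.${client}.${provider}.${language}`, and client names such as `huggingface.js` contain dots themselves, which is why the test script recovers the client with `slice(1, -2)`. A hypothetical helper (`parseFixtureName` is not part of this PR) showing the naming convention:

```ts
import * as path from "node:path/posix";

// Hypothetical helper, for illustration only: splits a fixture filename into its parts.
// Client names may contain dots ("huggingface.js"), so the client is everything between
// the leading index and the trailing provider + language extension.
function parseFixtureName(file: string) {
	const parts = path.basename(file).split(".");
	return {
		index: parts[0],
		client: parts.slice(1, -2).join("."),
		provider: parts[parts.length - 2],
		language: parts[parts.length - 1],
	};
}

console.log(parseFixtureName("0.huggingface.js.together.js"));
// -> { index: "0", client: "huggingface.js", provider: "together", language: "js" }
```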