diff --git a/packages/inference/README.md b/packages/inference/README.md
index 5425204d05..0ea60b2be7 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -651,9 +651,10 @@ You can use any Chat Completion API-compatible provider with the `chatCompletion
 ```typescript
 // Chat Completion Example
 const MISTRAL_KEY = process.env.MISTRAL_KEY;
-const hf = new InferenceClient(MISTRAL_KEY);
-const ep = hf.endpoint("https://api.mistral.ai");
-const stream = ep.chatCompletionStream({
+const hf = new InferenceClient(MISTRAL_KEY, {
+  endpointUrl: "https://api.mistral.ai",
+});
+const stream = hf.chatCompletionStream({
   model: "mistral-tiny",
   messages: [{ role: "user", content: "Complete the equation one + one = , just the answer" }],
 });
diff --git a/packages/inference/src/snippets/getInferenceSnippets.ts b/packages/inference/src/snippets/getInferenceSnippets.ts
index 19d8de3afe..991e4209fe 100644
--- a/packages/inference/src/snippets/getInferenceSnippets.ts
+++ b/packages/inference/src/snippets/getInferenceSnippets.ts
@@ -18,7 +18,8 @@ export type InferenceSnippetOptions = {
   streaming?: boolean;
   billTo?: string;
   accessToken?: string;
-  directRequest?: boolean;
+  directRequest?: boolean; // to bypass HF routing and call the provider directly
+  endpointUrl?: string; // to call a local endpoint directly
 } & Record<string, unknown>;
 
 const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const;
@@ -53,6 +54,7 @@ interface TemplateParams {
   methodName?: string; // specific to snippetBasic
   importBase64?: boolean; // specific to snippetImportRequests
   importJson?: boolean; // specific to snippetImportRequests
+  endpointUrl?: string;
 }
 
 // Helpers to find + load templates
@@ -172,6 +174,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
     {
       accessToken: accessTokenOrPlaceholder,
       provider,
+      endpointUrl: opts?.endpointUrl,
       ...inputs,
     } as RequestArgs,
     inferenceProviderMapping,
@@ -217,6 +220,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
     provider,
     providerModelId: providerModelId ?? model.id,
     billTo: opts?.billTo,
+    endpointUrl: opts?.endpointUrl,
   };
 
   /// Iterate over clients => check if a snippet exists => generate
@@ -265,7 +269,14 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
 
       /// Replace access token placeholder
       if (snippet.includes(placeholder)) {
-        snippet = replaceAccessTokenPlaceholder(opts?.directRequest, placeholder, snippet, language, provider);
+        snippet = replaceAccessTokenPlaceholder(
+          opts?.directRequest,
+          placeholder,
+          snippet,
+          language,
+          provider,
+          opts?.endpointUrl
+        );
       }
 
       /// Snippet is ready!
@@ -444,21 +455,24 @@ function replaceAccessTokenPlaceholder(
   placeholder: string,
   snippet: string,
   language: InferenceSnippetLanguage,
-  provider: InferenceProviderOrPolicy
+  provider: InferenceProviderOrPolicy,
+  endpointUrl?: string
 ): string {
   // If "opts.accessToken" is not set, the snippets are generated with a placeholder.
   // Once snippets are rendered, we replace the placeholder with code to fetch the access token from an environment variable.
 
   // Determine if HF_TOKEN or specific provider token should be used
   const useHfToken =
-    provider == "hf-inference" || // hf-inference provider => use $HF_TOKEN
-    (!directRequest && // if explicit directRequest => use provider-specific token
-      (!snippet.includes("https://") || // no URL provided => using a client => use $HF_TOKEN
-        snippet.includes("https://router.huggingface.co"))); // explicit routed request => use $HF_TOKEN
-
+    !endpointUrl && // custom endpointUrl => use a generic API_TOKEN
+    (provider == "hf-inference" || // hf-inference provider => use $HF_TOKEN
+      (!directRequest && // if explicit directRequest => use provider-specific token
+        (!snippet.includes("https://") || // no URL provided => using a client => use $HF_TOKEN
+          snippet.includes("https://router.huggingface.co")))); // explicit routed request => use $HF_TOKEN
   const accessTokenEnvVar = useHfToken
     ? "HF_TOKEN" // e.g. routed request or hf-inference
-    : provider.toUpperCase().replace("-", "_") + "_API_KEY"; // e.g. "REPLICATE_API_KEY"
+    : endpointUrl
+      ? "API_TOKEN"
+      : provider.toUpperCase().replace("-", "_") + "_API_KEY"; // e.g. "REPLICATE_API_KEY"
 
   // Replace the placeholder with the env variable
   if (language === "sh") {
diff --git a/packages/inference/src/snippets/templates/js/huggingface.js/basic.jinja b/packages/inference/src/snippets/templates/js/huggingface.js/basic.jinja
index 830713e6af..71abb328af 100644
--- a/packages/inference/src/snippets/templates/js/huggingface.js/basic.jinja
+++ b/packages/inference/src/snippets/templates/js/huggingface.js/basic.jinja
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
 const client = new InferenceClient("{{ accessToken }}");
 
 const output = await client.{{ methodName }}({
+{% if endpointUrl %}
+  endpointUrl: "{{ endpointUrl }}",
+{% endif %}
   model: "{{ model.id }}",
   inputs: {{ inputs.asObj.inputs }},
   provider: "{{ provider }}",
diff --git a/packages/inference/src/snippets/templates/js/huggingface.js/basicAudio.jinja b/packages/inference/src/snippets/templates/js/huggingface.js/basicAudio.jinja
index 7b34e11eb1..ba070486be 100644
--- a/packages/inference/src/snippets/templates/js/huggingface.js/basicAudio.jinja
+++ b/packages/inference/src/snippets/templates/js/huggingface.js/basicAudio.jinja
@@ -5,6 +5,9 @@ const client = new InferenceClient("{{ accessToken }}");
 const data = fs.readFileSync({{inputs.asObj.inputs}});
 
 const output = await client.{{ methodName }}({
+{% if endpointUrl %}
+  endpointUrl: "{{ endpointUrl }}",
+{% endif %}
   data,
   model: "{{ model.id }}",
   provider: "{{ provider }}",
diff --git a/packages/inference/src/snippets/templates/js/huggingface.js/basicImage.jinja b/packages/inference/src/snippets/templates/js/huggingface.js/basicImage.jinja
index 7b34e11eb1..ba070486be 100644
--- a/packages/inference/src/snippets/templates/js/huggingface.js/basicImage.jinja
+++ b/packages/inference/src/snippets/templates/js/huggingface.js/basicImage.jinja
@@ -5,6 +5,9 @@ const client = new InferenceClient("{{ accessToken }}");
 const data = fs.readFileSync({{inputs.asObj.inputs}});
 
 const output = await client.{{ methodName }}({
+{% if endpointUrl %}
+  endpointUrl: "{{ endpointUrl }}",
+{% endif %}
   data,
   model: "{{ model.id }}",
   provider: "{{ provider }}",
diff --git a/packages/inference/src/snippets/templates/js/huggingface.js/conversational.jinja b/packages/inference/src/snippets/templates/js/huggingface.js/conversational.jinja
index 52d9029625..cbf306926f 100644
--- a/packages/inference/src/snippets/templates/js/huggingface.js/conversational.jinja
+++ b/packages/inference/src/snippets/templates/js/huggingface.js/conversational.jinja
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
 const client = new InferenceClient("{{ accessToken }}");
 
 const chatCompletion = await client.chatCompletion({
+{% if endpointUrl %}
+  endpointUrl: "{{ endpointUrl }}",
+{% endif %}
   provider: "{{ provider }}",
   model: "{{ model.id }}",
 {{ inputs.asTsString }}
diff --git a/packages/inference/src/snippets/templates/js/huggingface.js/conversationalStream.jinja b/packages/inference/src/snippets/templates/js/huggingface.js/conversationalStream.jinja
index ac42a557cd..57f5e3aa06 100644
--- a/packages/inference/src/snippets/templates/js/huggingface.js/conversationalStream.jinja
+++ b/packages/inference/src/snippets/templates/js/huggingface.js/conversationalStream.jinja
@@ -5,6 +5,9 @@ const client = new InferenceClient("{{ accessToken }}");
 let out = "";
 
 const stream = client.chatCompletionStream({
+{% if endpointUrl %}
+  endpointUrl: "{{ endpointUrl }}",
+{% endif %}
   provider: "{{ provider }}",
   model: "{{ model.id }}",
 {{ inputs.asTsString }}
diff --git a/packages/inference/src/snippets/templates/js/huggingface.js/textToImage.jinja b/packages/inference/src/snippets/templates/js/huggingface.js/textToImage.jinja
index e24e2b13d5..4abe1340a1 100644
--- a/packages/inference/src/snippets/templates/js/huggingface.js/textToImage.jinja
+++ b/packages/inference/src/snippets/templates/js/huggingface.js/textToImage.jinja
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
 const client = new InferenceClient("{{ accessToken }}");
 
 const image = await client.textToImage({
+{% if endpointUrl %}
+  endpointUrl: "{{ endpointUrl }}",
+{% endif %}
   provider: "{{ provider }}",
   model: "{{ model.id }}",
   inputs: {{ inputs.asObj.inputs }},
diff --git a/packages/inference/src/snippets/templates/js/huggingface.js/textToSpeech.jinja b/packages/inference/src/snippets/templates/js/huggingface.js/textToSpeech.jinja
index 6ac3f3e9ba..389a843724 100644
--- a/packages/inference/src/snippets/templates/js/huggingface.js/textToSpeech.jinja
+++ b/packages/inference/src/snippets/templates/js/huggingface.js/textToSpeech.jinja
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
 const client = new InferenceClient("{{ accessToken }}");
 
 const audio = await client.textToSpeech({
+{% if endpointUrl %}
+  endpointUrl: "{{ endpointUrl }}",
+{% endif %}
   provider: "{{ provider }}",
   model: "{{ model.id }}",
   inputs: {{ inputs.asObj.inputs }},
diff --git a/packages/inference/src/snippets/templates/js/huggingface.js/textToVideo.jinja b/packages/inference/src/snippets/templates/js/huggingface.js/textToVideo.jinja
index 578d4309db..8063b2017b 100644
--- a/packages/inference/src/snippets/templates/js/huggingface.js/textToVideo.jinja
+++ b/packages/inference/src/snippets/templates/js/huggingface.js/textToVideo.jinja
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
 const client = new InferenceClient("{{ accessToken }}");
 
 const video = await client.textToVideo({
+{% if endpointUrl %}
+  endpointUrl: "{{ endpointUrl }}",
+{% endif %}
   provider: "{{ provider }}",
   model: "{{ model.id }}",
   inputs: {{ inputs.asObj.inputs }},
diff --git a/packages/inference/src/snippets/templates/python/huggingface_hub/importInferenceClient.jinja b/packages/inference/src/snippets/templates/python/huggingface_hub/importInferenceClient.jinja
index 4e09aef96f..df7e88d4a9 100644
--- a/packages/inference/src/snippets/templates/python/huggingface_hub/importInferenceClient.jinja
+++ b/packages/inference/src/snippets/templates/python/huggingface_hub/importInferenceClient.jinja
@@ -1,6 +1,9 @@
 from huggingface_hub import InferenceClient
 
 client = InferenceClient(
+{% if endpointUrl %}
+    base_url="{{ baseUrl }}",
+{% endif %}
     provider="{{ provider }}",
     api_key="{{ accessToken }}",
 {% if billTo %}
diff --git a/packages/tasks-gen/scripts/generate-snippets-fixtures.ts b/packages/tasks-gen/scripts/generate-snippets-fixtures.ts
index c69629631b..e5df80f8d8 100644
--- a/packages/tasks-gen/scripts/generate-snippets-fixtures.ts
+++ b/packages/tasks-gen/scripts/generate-snippets-fixtures.ts
@@ -95,6 +95,18 @@ const TEST_CASES: {
     providers: ["hf-inference", "fireworks-ai"],
     opts: { streaming: true },
   },
+  {
+    testName: "conversational-llm-custom-endpoint",
+    task: "conversational",
+    model: {
+      id: "meta-llama/Llama-3.1-8B-Instruct",
+      pipeline_tag: "text-generation",
+      tags: ["conversational"],
+      inference: "",
+    },
+    providers: ["hf-inference"],
+    opts: { endpointUrl: "http://localhost:8080/v1" },
+  },
   {
     testName: "document-question-answering",
     task: "document-question-answering",
diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/js/huggingface.js/0.hf-inference.js b/packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/js/huggingface.js/0.hf-inference.js
new file mode 100644
index 0000000000..12c7d989c2
--- /dev/null
+++ b/packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/js/huggingface.js/0.hf-inference.js
@@ -0,0 +1,17 @@
+import { InferenceClient } from "@huggingface/inference";
+
+const client = new InferenceClient(process.env.API_TOKEN);
+
+const chatCompletion = await client.chatCompletion({
+  endpointUrl: "http://localhost:8080/v1",
+  provider: "hf-inference",
+  model: "meta-llama/Llama-3.1-8B-Instruct",
+  messages: [
+    {
+      role: "user",
+      content: "What is the capital of France?",
+    },
+  ],
+});
+
+console.log(chatCompletion.choices[0].message);
\ No newline at end of file
diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/js/openai/0.hf-inference.js b/packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/js/openai/0.hf-inference.js
new file mode 100644
index 0000000000..63b3d47136
--- /dev/null
+++ b/packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/js/openai/0.hf-inference.js
@@ -0,0 +1,18 @@
+import { OpenAI } from "openai";
+
+const client = new OpenAI({
+  baseURL: "http://localhost:8080/v1",
+  apiKey: process.env.API_TOKEN,
+});
+
+const chatCompletion = await client.chat.completions.create({
+  model: "meta-llama/Llama-3.1-8B-Instruct",
+  messages: [
+    {
+      role: "user",
+      content: "What is the capital of France?",
+    },
+  ],
+});
+
+console.log(chatCompletion.choices[0].message);
\ No newline at end of file
diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/python/huggingface_hub/0.hf-inference.py b/packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/python/huggingface_hub/0.hf-inference.py
new file mode 100644
index 0000000000..2d9d5953a2
--- /dev/null
+++ b/packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/python/huggingface_hub/0.hf-inference.py
@@ -0,0 +1,20 @@
+import os
+from huggingface_hub import InferenceClient
+
+client = InferenceClient(
+    base_url="http://localhost:8080/v1",
+    provider="hf-inference",
+    api_key=os.environ["API_TOKEN"],
+)
+
+completion = client.chat.completions.create(
+    model="meta-llama/Llama-3.1-8B-Instruct",
+    messages=[
+        {
+            "role": "user",
+            "content": "What is the capital of France?"
+        }
+    ],
+)
+
+print(completion.choices[0].message)
\ No newline at end of file
diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/python/openai/0.hf-inference.py b/packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/python/openai/0.hf-inference.py
new file mode 100644
index 0000000000..60161ab968
--- /dev/null
+++ b/packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/python/openai/0.hf-inference.py
@@ -0,0 +1,19 @@
+import os
+from openai import OpenAI
+
+client = OpenAI(
+    base_url="http://localhost:8080/v1",
+    api_key=os.environ["API_TOKEN"],
+)
+
+completion = client.chat.completions.create(
+    model="meta-llama/Llama-3.1-8B-Instruct",
+    messages=[
+        {
+            "role": "user",
+            "content": "What is the capital of France?"
+        }
+    ],
+)
+
+print(completion.choices[0].message)
\ No newline at end of file
diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/python/requests/0.hf-inference.py b/packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/python/requests/0.hf-inference.py
new file mode 100644
index 0000000000..5a7e4ce0e6
--- /dev/null
+++ b/packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/python/requests/0.hf-inference.py
@@ -0,0 +1,23 @@
+import os
+import requests
+
+API_URL = "http://localhost:8080/v1/chat/completions"
+headers = {
+    "Authorization": f"Bearer {os.environ['API_TOKEN']}",
+}
+
+def query(payload):
+    response = requests.post(API_URL, headers=headers, json=payload)
+    return response.json()
+
+response = query({
+    "messages": [
+        {
+            "role": "user",
+            "content": "What is the capital of France?"
+        }
+    ],
+    "model": "meta-llama/Llama-3.1-8B-Instruct"
+})
+
+print(response["choices"][0]["message"])
\ No newline at end of file
diff --git a/packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/sh/curl/0.hf-inference.sh b/packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/sh/curl/0.hf-inference.sh
new file mode 100644
index 0000000000..1c36a1cad7
--- /dev/null
+++ b/packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/sh/curl/0.hf-inference.sh
@@ -0,0 +1,13 @@
+curl http://localhost:8080/v1/chat/completions \
+    -H "Authorization: Bearer $API_TOKEN" \
+    -H 'Content-Type: application/json' \
+    -d '{
+        "messages": [
+            {
+                "role": "user",
+                "content": "What is the capital of France?"
+            }
+        ],
+        "model": "meta-llama/Llama-3.1-8B-Instruct",
+        "stream": false
+    }'
\ No newline at end of file
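
To see the new option end to end, here is a minimal sketch of how a consumer could request snippets for a locally hosted, OpenAI-compatible endpoint using `endpointUrl`. It assumes the snippet module's `getInferenceSnippets(model, provider, inferenceProviderMapping?, opts?)` entry point, exposed through the package's `snippets` export, and an `InferenceSnippet[]` return value with a `content` field; the exact export path and parameter order should be checked against `getInferenceSnippets.ts` rather than taken from this sketch.

```typescript
// Hypothetical usage sketch (not part of the diff): generate snippets that target a
// local OpenAI-compatible server via the new `endpointUrl` option.
import { snippets } from "@huggingface/inference"; // assumed export path

// Model descriptor mirroring the "conversational-llm-custom-endpoint" test case above.
const model = {
  id: "meta-llama/Llama-3.1-8B-Instruct",
  pipeline_tag: "text-generation" as const,
  tags: ["conversational"],
  inference: "",
};

// `endpointUrl` makes the rendered snippets call the given URL directly and switches
// the access token placeholder to a generic API_TOKEN environment variable.
const generated = snippets.getInferenceSnippets(model, "hf-inference", undefined, {
  endpointUrl: "http://localhost:8080/v1", // e.g. a local TGI or vLLM server (assumed)
});

for (const snippet of generated) {
  console.log(snippet.content);
}
```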