Merged
7 changes: 4 additions & 3 deletions packages/inference/README.md
@@ -651,9 +651,10 @@ You can use any Chat Completion API-compatible provider with the `chatCompletion
```typescript
// Chat Completion Example
const MISTRAL_KEY = process.env.MISTRAL_KEY;
-const hf = new InferenceClient(MISTRAL_KEY);
-const ep = hf.endpoint("https://api.mistral.ai");
-const stream = ep.chatCompletionStream({
+const hf = new InferenceClient(MISTRAL_KEY, {
+	endpointUrl: "https://api.mistral.ai",
+});
+const stream = hf.chatCompletionStream({
model: "mistral-tiny",
messages: [{ role: "user", content: "Complete the equation one + one = , just the answer" }],
});
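Applied in full, the updated README example reads as follows. The streaming loop at the end is a sketch of typical usage (the diff collapses that part of the file), with the chunk shape assumed from the package's standard streaming output:

```typescript
import { InferenceClient } from "@huggingface/inference";

// Chat Completion Example: call Mistral's OpenAI-compatible API directly.
const MISTRAL_KEY = process.env.MISTRAL_KEY;
const hf = new InferenceClient(MISTRAL_KEY, {
	endpointUrl: "https://api.mistral.ai",
});

const stream = hf.chatCompletionStream({
	model: "mistral-tiny",
	messages: [{ role: "user", content: "Complete the equation one + one = , just the answer" }],
});

// Consume the stream chunk by chunk (chunk shape assumed, not shown in this diff).
for await (const chunk of stream) {
	process.stdout.write(chunk.choices[0]?.delta?.content ?? "");
}
```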
32 changes: 23 additions & 9 deletions packages/inference/src/snippets/getInferenceSnippets.ts
@@ -18,7 +18,8 @@ export type InferenceSnippetOptions = {
streaming?: boolean;
billTo?: string;
accessToken?: string;
-directRequest?: boolean;
+directRequest?: boolean; // to bypass HF routing and call the provider directly
+endpointUrl?: string; // to call a local endpoint directly
} & Record<string, unknown>;

const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const;
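The options type above now carries both escape hatches. A fully populated object might look like this (a sketch; all values are illustrative):

```typescript
// Illustrative values only; InferenceSnippetOptions is the type defined above.
const opts: InferenceSnippetOptions = {
	streaming: false,
	billTo: "my-org", // hypothetical organization name
	accessToken: process.env.HF_TOKEN,
	directRequest: false, // keep HF routing
	endpointUrl: "http://localhost:8080/v1", // or bypass everything and hit a local endpoint
};
```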
@@ -53,6 +54,7 @@ interface TemplateParams {
methodName?: string; // specific to snippetBasic
importBase64?: boolean; // specific to snippetImportRequests
importJson?: boolean; // specific to snippetImportRequests
+endpointUrl?: string;
}

// Helpers to find + load templates
@@ -172,6 +174,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
{
accessToken: accessTokenOrPlaceholder,
provider,
+endpointUrl: opts?.endpointUrl,
...inputs,
} as RequestArgs,
inferenceProviderMapping,
@@ -217,6 +220,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
provider,
providerModelId: providerModelId ?? model.id,
billTo: opts?.billTo,
+endpointUrl: opts?.endpointUrl,
};

/// Iterate over clients => check if a snippet exists => generate
@@ -265,7 +269,14 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar

/// Replace access token placeholder
if (snippet.includes(placeholder)) {
-snippet = replaceAccessTokenPlaceholder(opts?.directRequest, placeholder, snippet, language, provider);
+snippet = replaceAccessTokenPlaceholder(
+	opts?.directRequest,
+	placeholder,
+	snippet,
+	language,
+	provider,
+	opts?.endpointUrl
+);
}

/// Snippet is ready!
@@ -444,21 +455,24 @@ function replaceAccessTokenPlaceholder(
placeholder: string,
snippet: string,
language: InferenceSnippetLanguage,
-provider: InferenceProviderOrPolicy
+provider: InferenceProviderOrPolicy,
+endpointUrl?: string
): string {
// If "opts.accessToken" is not set, the snippets are generated with a placeholder.
// Once snippets are rendered, we replace the placeholder with code to fetch the access token from an environment variable.

// Determine if HF_TOKEN or specific provider token should be used
const useHfToken =
provider == "hf-inference" || // hf-inference provider => use $HF_TOKEN
(!directRequest && // if explicit directRequest => use provider-specific token
(!snippet.includes("https://") || // no URL provided => using a client => use $HF_TOKEN
snippet.includes("https://router.huggingface.co"))); // explicit routed request => use $HF_TOKEN

!endpointUrl && // custom endpointUrl => use a generic API_TOKEN
(provider == "hf-inference" || // hf-inference provider => use $HF_TOKEN
(!directRequest && // if explicit directRequest => use provider-specific token
(!snippet.includes("https://") || // no URL provided => using a client => use $HF_TOKEN
snippet.includes("https://router.huggingface.co")))); // explicit routed request => use $HF_TOKEN
const accessTokenEnvVar = useHfToken
? "HF_TOKEN" // e.g. routed request or hf-inference
-: provider.toUpperCase().replace("-", "_") + "_API_KEY"; // e.g. "REPLICATE_API_KEY"
+: endpointUrl
+? "API_TOKEN"
+: provider.toUpperCase().replace("-", "_") + "_API_KEY"; // e.g. "REPLICATE_API_KEY"

// Replace the placeholder with the env variable
if (language === "sh") {
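In isolation, the updated token-selection rules amount to the following decision function (a standalone sketch with a hypothetical name, mirroring the logic shown above):

```typescript
// Sketch: which environment variable should a generated snippet read its token from?
function pickAccessTokenEnvVar(
	provider: string,
	snippet: string,
	directRequest?: boolean,
	endpointUrl?: string
): string {
	const useHfToken =
		!endpointUrl && // custom endpointUrl => use a generic API_TOKEN
		(provider == "hf-inference" || // hf-inference provider => use $HF_TOKEN
			(!directRequest && // explicit directRequest => use a provider-specific token
				(!snippet.includes("https://") || // no URL => using a client => use $HF_TOKEN
					snippet.includes("https://router.huggingface.co")))); // routed request => use $HF_TOKEN
	return useHfToken
		? "HF_TOKEN"
		: endpointUrl
			? "API_TOKEN"
			: provider.toUpperCase().replace("-", "_") + "_API_KEY"; // e.g. "REPLICATE_API_KEY"
}

// e.g. pickAccessTokenEnvVar("hf-inference", "...", false, "http://localhost:8080/v1") === "API_TOKEN"
```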
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
const client = new InferenceClient("{{ accessToken }}");

const output = await client.{{ methodName }}({
+{% if endpointUrl %}
+	endpointUrl: "{{ endpointUrl }}",
+{% endif %}
model: "{{ model.id }}",
inputs: {{ inputs.asObj.inputs }},
provider: "{{ provider }}",
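For illustration, rendering the template above with `endpointUrl` set would produce a snippet along these lines (the method name, model, and inputs are hypothetical placeholder values):

```typescript
import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient(process.env.API_TOKEN);

// Hypothetical render: methodName = textClassification, endpointUrl = a local server.
const output = await client.textClassification({
	endpointUrl: "http://localhost:8080",
	model: "distilbert-base-uncased-finetuned-sst-2-english",
	inputs: "Today is a great day",
	provider: "hf-inference",
});

console.log(output);
```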
@@ -5,6 +5,9 @@ const client = new InferenceClient("{{ accessToken }}");
const data = fs.readFileSync({{inputs.asObj.inputs}});

const output = await client.{{ methodName }}({
+{% if endpointUrl %}
+	endpointUrl: "{{ endpointUrl }}",
+{% endif %}
data,
model: "{{ model.id }}",
provider: "{{ provider }}",
@@ -5,6 +5,9 @@ const client = new InferenceClient("{{ accessToken }}");
const data = fs.readFileSync({{inputs.asObj.inputs}});

const output = await client.{{ methodName }}({
+{% if endpointUrl %}
+	endpointUrl: "{{ endpointUrl }}",
+{% endif %}
data,
model: "{{ model.id }}",
provider: "{{ provider }}",
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
const client = new InferenceClient("{{ accessToken }}");

const chatCompletion = await client.chatCompletion({
+{% if endpointUrl %}
+	endpointUrl: "{{ endpointUrl }}",
+{% endif %}
provider: "{{ provider }}",
model: "{{ model.id }}",
{{ inputs.asTsString }}
@@ -5,6 +5,9 @@ const client = new InferenceClient("{{ accessToken }}");
let out = "";

const stream = client.chatCompletionStream({
+{% if endpointUrl %}
+	endpointUrl: "{{ endpointUrl }}",
+{% endif %}
provider: "{{ provider }}",
model: "{{ model.id }}",
{{ inputs.asTsString }}
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
const client = new InferenceClient("{{ accessToken }}");

const image = await client.textToImage({
+{% if endpointUrl %}
+	endpointUrl: "{{ endpointUrl }}",
+{% endif %}
provider: "{{ provider }}",
model: "{{ model.id }}",
inputs: {{ inputs.asObj.inputs }},
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
const client = new InferenceClient("{{ accessToken }}");

const audio = await client.textToSpeech({
+{% if endpointUrl %}
+	endpointUrl: "{{ endpointUrl }}",
+{% endif %}
provider: "{{ provider }}",
model: "{{ model.id }}",
inputs: {{ inputs.asObj.inputs }},
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
const client = new InferenceClient("{{ accessToken }}");

const video = await client.textToVideo({
+{% if endpointUrl %}
+	endpointUrl: "{{ endpointUrl }}",
+{% endif %}
provider: "{{ provider }}",
model: "{{ model.id }}",
inputs: {{ inputs.asObj.inputs }},
@@ -1,6 +1,9 @@
from huggingface_hub import InferenceClient

client = InferenceClient(
+{% if endpointUrl %}
+	base_url="{{ baseUrl }}",
+{% endif %}
provider="{{ provider }}",
api_key="{{ accessToken }}",
{% if billTo %}
12 changes: 12 additions & 0 deletions packages/tasks-gen/scripts/generate-snippets-fixtures.ts
@@ -95,6 +95,18 @@ const TEST_CASES: {
providers: ["hf-inference", "fireworks-ai"],
opts: { streaming: true },
},
+{
+	testName: "conversational-llm-custom-endpoint",
+	task: "conversational",
+	model: {
+		id: "meta-llama/Llama-3.1-8B-Instruct",
+		pipeline_tag: "text-generation",
+		tags: ["conversational"],
+		inference: "",
+	},
+	providers: ["hf-inference"],
+	opts: { endpointUrl: "http://localhost:8080/v1" },
+},
{
testName: "document-question-answering",
task: "document-question-answering",
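This test case drives the new fixtures below. For reference, a direct invocation of the snippet generator with the same inputs might look like this; the `snippets.getInferenceSnippets` export and its signature are assumptions inferred from the options type above, not shown in this diff:

```typescript
import { snippets } from "@huggingface/inference";

// Hypothetical call: (model, provider, inferenceProviderMapping, opts) is an assumed signature.
const generated = snippets.getInferenceSnippets(
	{
		id: "meta-llama/Llama-3.1-8B-Instruct",
		pipeline_tag: "text-generation",
		tags: ["conversational"],
		inference: "",
	},
	"hf-inference",
	undefined, // no provider mapping needed when calling an endpoint directly
	{ endpointUrl: "http://localhost:8080/v1" }
);

console.log(generated);
```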
@@ -0,0 +1,17 @@
import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient(process.env.API_TOKEN);

const chatCompletion = await client.chatCompletion({
	endpointUrl: "http://localhost:8080/v1",
	provider: "hf-inference",
	model: "meta-llama/Llama-3.1-8B-Instruct",
	messages: [
		{
			role: "user",
			content: "What is the capital of France?",
		},
	],
});

console.log(chatCompletion.choices[0].message);
@@ -0,0 +1,18 @@
import { OpenAI } from "openai";

const client = new OpenAI({
	baseURL: "http://localhost:8080/v1",
	apiKey: process.env.API_TOKEN,
});

const chatCompletion = await client.chat.completions.create({
	model: "meta-llama/Llama-3.1-8B-Instruct",
	messages: [
		{
			role: "user",
			content: "What is the capital of France?",
		},
	],
});

console.log(chatCompletion.choices[0].message);
@@ -0,0 +1,20 @@
import os
from huggingface_hub import InferenceClient

client = InferenceClient(
    base_url="http://localhost:8080/v1",
    provider="hf-inference",
    api_key=os.environ["API_TOKEN"],
)

completion = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[
        {
            "role": "user",
            "content": "What is the capital of France?"
        }
    ],
)

print(completion.choices[0].message)
@@ -0,0 +1,19 @@
import os
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8080/v1",
    api_key=os.environ["API_TOKEN"],
)

completion = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[
        {
            "role": "user",
            "content": "What is the capital of France?"
        }
    ],
)

print(completion.choices[0].message)
@@ -0,0 +1,23 @@
import os
import requests

API_URL = "http://localhost:8080/v1/chat/completions"
headers = {
    "Authorization": f"Bearer {os.environ['API_TOKEN']}",
}

def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

response = query({
    "messages": [
        {
            "role": "user",
            "content": "What is the capital of France?"
        }
    ],
    "model": "meta-llama/Llama-3.1-8B-Instruct"
})

print(response["choices"][0]["message"])
@@ -0,0 +1,13 @@
curl http://localhost:8080/v1/chat/completions \
    -H "Authorization: Bearer $API_TOKEN" \
    -H 'Content-Type: application/json' \
    -d '{
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ],
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "stream": false
    }'
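All of the fixtures target the same local server, so the cURL call above can be cross-checked against any of them. An equivalent probe in TypeScript, using only the built-in fetch (a sketch assuming an OpenAI-compatible server is listening on port 8080):

```typescript
// Mirror of the cURL fixture: POST a chat completion to the local endpoint.
const response = await fetch("http://localhost:8080/v1/chat/completions", {
	method: "POST",
	headers: {
		Authorization: `Bearer ${process.env.API_TOKEN}`,
		"Content-Type": "application/json",
	},
	body: JSON.stringify({
		messages: [{ role: "user", content: "What is the capital of France?" }],
		model: "meta-llama/Llama-3.1-8B-Instruct",
		stream: false,
	}),
});

const completion = await response.json();
console.log(completion.choices[0].message);
```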