Commit 782078b

Allow to provide accessToken
1 parent c3d664d commit 782078b

8 files changed: +130 −9 lines changed

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 13 additions & 9 deletions
@@ -14,7 +14,10 @@ import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions.j
 import type { InferenceProviderOrPolicy, InferenceTask, RequestArgs } from "../types.js";
 import { templates } from "./templates.exported.js";
 
-export type InferenceSnippetOptions = { streaming?: boolean; billTo?: string } & Record<string, unknown>;
+export type InferenceSnippetOptions = { streaming?: boolean; billTo?: string; accessToken?: string } & Record<
+	string,
+	unknown
+>;
 
 const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const;
 const JS_CLIENTS = ["fetch", "huggingface.js", "openai"] as const;
@@ -150,14 +153,15 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
 		console.error(`Failed to get provider helper for ${provider} (${task})`, e);
 		return [];
 	}
+	const accessTokenOrPlaceholder = opts?.accessToken ?? ACCESS_TOKEN_PLACEHOLDER;
 
 	/// Prepare inputs + make request
 	const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
 	const request = makeRequestOptionsFromResolvedModel(
 		providerModelId,
 		providerHelper,
 		{
-			accessToken: ACCESS_TOKEN_PLACEHOLDER,
+			accessToken: accessTokenOrPlaceholder,
 			provider,
 			...inputs,
 		} as RequestArgs,
@@ -182,7 +186,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
 
 	/// Prepare template injection data
 	const params: TemplateParams = {
-		accessToken: ACCESS_TOKEN_PLACEHOLDER,
+		accessToken: accessTokenOrPlaceholder,
 		authorizationHeader: (request.info.headers as Record<string, string>)?.Authorization,
 		baseUrl: removeSuffix(request.url, "/chat/completions"),
 		fullUrl: request.url,
@@ -251,7 +255,9 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
 	}
 
 	/// Replace access token placeholder
-	snippet = replaceAccessTokenPlaceholder(snippet, language, provider);
+	if (snippet.includes(ACCESS_TOKEN_PLACEHOLDER)) {
+		snippet = replaceAccessTokenPlaceholder(snippet, language, provider);
+	}
 
 	/// Snippet is ready!
 	return { language, client: client as string, content: snippet };
@@ -429,8 +435,8 @@ function replaceAccessTokenPlaceholder(
 	language: InferenceSnippetLanguage,
 	provider: InferenceProviderOrPolicy
 ): string {
-	// The snippets are generated with a placeholder in place of the access token.
-	// Once snippets are rendered, we replace the placeholder with correct code to fetch the access token from an environment variable.
+	// If "opts.accessToken" is not set, the snippets are generated with a placeholder.
+	// Once snippets are rendered, we replace the placeholder with code to fetch the access token from an environment variable.
 
 	// Determine if HF_TOKEN or specific provider token should be used
 	const accessTokenEnvVar =
@@ -447,9 +453,7 @@
 			`"Authorization: Bearer $${accessTokenEnvVar}"` // e.g. "Authorization: Bearer $HF_TOKEN"
 		);
 	} else if (language === "python") {
-		if (snippet.includes(ACCESS_TOKEN_PLACEHOLDER)) {
-			snippet = "import os\n" + snippet;
-		}
+		snippet = "import os\n" + snippet;
 		snippet = snippet.replace(
 			`"${ACCESS_TOKEN_PLACEHOLDER}"`,
 			`os.environ["${accessTokenEnvVar}"]` // e.g. os.environ["HF_TOKEN"]
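Net effect: `accessTokenOrPlaceholder` threads an optional caller-supplied token through both the request options and the template parameters, and the placeholder-replacement pass now runs only when the placeholder is actually present. A minimal sketch of the calling side; the exact export path and signature of `getInferenceSnippets` are assumptions here, not shown in this diff:

import { snippets } from "@huggingface/inference";

// Minimal model metadata, mirroring the "with-access-token" test case added below.
const model = {
	id: "meta-llama/Llama-3.1-8B-Instruct",
	pipeline_tag: "text-generation",
	tags: ["conversational"],
	inference: "",
};

// Default: snippets are rendered with the placeholder, then rewritten to read the
// token from an environment variable (e.g. process.env.HF_TOKEN for hf-inference).
const fromEnvVar = snippets.getInferenceSnippets(model, "hf-inference");

// With opts.accessToken: the token is rendered verbatim into every snippet,
// and replaceAccessTokenPlaceholder is skipped.
const inlined = snippets.getInferenceSnippets(model, "hf-inference", { accessToken: "hf_xxx" });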

packages/tasks-gen/scripts/generate-snippets-fixtures.ts

Lines changed: 12 additions & 0 deletions
@@ -240,6 +240,18 @@ const TEST_CASES: {
 		providers: ["hf-inference"],
 		opts: { billTo: "huggingface" },
 	},
+	{
+		testName: "with-access-token",
+		task: "conversational",
+		model: {
+			id: "meta-llama/Llama-3.1-8B-Instruct",
+			pipeline_tag: "text-generation",
+			tags: ["conversational"],
+			inference: "",
+		},
+		providers: ["hf-inference"],
+		opts: { accessToken: "hf_xxx" },
+	},
 	{
 		testName: "text-to-speech",
 		task: "text-to-speech",
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+import { InferenceClient } from "@huggingface/inference";
+
+const client = new InferenceClient("hf_xxx");
+
+const chatCompletion = await client.chatCompletion({
+	provider: "hf-inference",
+	model: "meta-llama/Llama-3.1-8B-Instruct",
+	messages: [
+		{
+			role: "user",
+			content: "What is the capital of France?",
+		},
+	],
+});
+
+console.log(chatCompletion.choices[0].message);
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+import { OpenAI } from "openai";
+
+const client = new OpenAI({
+	baseURL: "https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1",
+	apiKey: "hf_xxx",
+});
+
+const chatCompletion = await client.chat.completions.create({
+	model: "meta-llama/Llama-3.1-8B-Instruct",
+	messages: [
+		{
+			role: "user",
+			content: "What is the capital of France?",
+		},
+	],
+});
+
+console.log(chatCompletion.choices[0].message);
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+from huggingface_hub import InferenceClient
+
+client = InferenceClient(
+    provider="hf-inference",
+    api_key="hf_xxx",
+)
+
+completion = client.chat.completions.create(
+    model="meta-llama/Llama-3.1-8B-Instruct",
+    messages=[
+        {
+            "role": "user",
+            "content": "What is the capital of France?"
+        }
+    ],
+)
+
+print(completion.choices[0].message)
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+from openai import OpenAI
+
+client = OpenAI(
+    base_url="https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1",
+    api_key="hf_xxx",
+)
+
+completion = client.chat.completions.create(
+    model="meta-llama/Llama-3.1-8B-Instruct",
+    messages=[
+        {
+            "role": "user",
+            "content": "What is the capital of France?"
+        }
+    ],
+)
+
+print(completion.choices[0].message)
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+import requests
+
+API_URL = "https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions"
+headers = {
+    "Authorization": "Bearer hf_xxx",
+}
+
+def query(payload):
+    response = requests.post(API_URL, headers=headers, json=payload)
+    return response.json()
+
+response = query({
+    "messages": [
+        {
+            "role": "user",
+            "content": "What is the capital of France?"
+        }
+    ],
+    "model": "meta-llama/Llama-3.1-8B-Instruct"
+})
+
+print(response["choices"][0]["message"])
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+curl https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions \
+    -H 'Authorization: Bearer hf_xxx' \
+    -H 'Content-Type: application/json' \
+    -d '{
+        "messages": [
+            {
+                "role": "user",
+                "content": "What is the capital of France?"
+            }
+        ],
+        "model": "meta-llama/Llama-3.1-8B-Instruct",
+        "stream": false
+    }'
