Skip to content

Commit 3857067

Browse files
committed
[InferenceSnippet] Take token from env variable if not set
1 parent 0416cec commit 3857067

File tree

112 files changed

+222
-113
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

112 files changed

+222
-113
lines changed

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ const HF_JS_METHODS: Partial<Record<WidgetType, string>> = {
121121
translation: "translation",
122122
};
123123

124+
const ACCESS_TOKEN_PLACEHOLDER = "<ACCESS_TOKEN>"; // Placeholder to replace with env variable in snippets
125+
124126
// Snippet generators
125127
const snippetGenerator = (templateName: string, inputPreparationFn?: InputPreparationFn) => {
126128
return (
@@ -149,13 +151,15 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
149151
console.error(`Failed to get provider helper for ${provider} (${task})`, e);
150152
return [];
151153
}
154+
const accessTokenOrPlaceholder = accessToken == "" ? ACCESS_TOKEN_PLACEHOLDER : accessToken;
155+
152156
/// Prepare inputs + make request
153157
const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
154158
const request = makeRequestOptionsFromResolvedModel(
155159
providerModelId,
156160
providerHelper,
157161
{
158-
accessToken,
162+
accessToken: accessTokenOrPlaceholder,
159163
provider,
160164
...inputs,
161165
} as RequestArgs,
@@ -180,7 +184,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
180184

181185
/// Prepare template injection data
182186
const params: TemplateParams = {
183-
accessToken,
187+
accessToken: accessTokenOrPlaceholder,
184188
authorizationHeader: (request.info.headers as Record<string, string>)?.Authorization,
185189
baseUrl: removeSuffix(request.url, "/chat/completions"),
186190
fullUrl: request.url,
@@ -248,6 +252,11 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
248252
snippet = `${importSection}\n\n${snippet}`;
249253
}
250254

255+
/// Replace access token placeholder
256+
if (snippet.includes(ACCESS_TOKEN_PLACEHOLDER)) {
257+
snippet = replaceAccessTokenPlaceholder(snippet, language, provider);
258+
}
259+
251260
/// Snippet is ready!
252261
return { language, client: client as string, content: snippet };
253262
})
@@ -420,3 +429,48 @@ function indentString(str: string): string {
420429
function removeSuffix(str: string, suffix: string) {
421430
return str.endsWith(suffix) ? str.slice(0, -suffix.length) : str;
422431
}
432+
433+
function replaceAccessTokenPlaceholder(
434+
snippet: string,
435+
language: InferenceSnippetLanguage,
436+
provider: InferenceProviderOrPolicy
437+
): string {
438+
// If "accessToken" is empty, the snippets are generated with a placeholder.
439+
// Once snippets are rendered, we replace the placeholder with code to fetch the access token from an environment variable.
440+
441+
// Determine if HF_TOKEN or specific provider token should be used
442+
const accessTokenEnvVar =
443+
!snippet.includes("https://") || // no URL provided => using a client => use $HF_TOKEN
444+
snippet.includes("https://router.huggingface.co") || // explicit routed request => use $HF_TOKEN
445+
provider == "hf-inference" // hf-inference provider => use $HF_TOKEN
446+
? "HF_TOKEN"
447+
: provider.toUpperCase().replace("-", "_") + "_API_TOKEN"; // e.g. "REPLICATE_API_TOKEN"
448+
449+
// Replace the placeholder with the env variable
450+
if (language === "sh") {
451+
snippet = snippet.replace(
452+
`'Authorization: Bearer ${ACCESS_TOKEN_PLACEHOLDER}'`,
453+
`"Authorization: Bearer $${accessTokenEnvVar}"` // e.g. "Authorization: Bearer $HF_TOKEN"
454+
);
455+
} else if (language === "python") {
456+
snippet = "import os\n" + snippet;
457+
snippet = snippet.replace(
458+
`"${ACCESS_TOKEN_PLACEHOLDER}"`,
459+
`os.getenv("${accessTokenEnvVar}")` // e.g. os.getenv("HF_TOKEN")
460+
);
461+
snippet = snippet.replace(
462+
`"Bearer <ACCESS_TOKEN>"`,
463+
`f"Bearer {os.getenv('${accessTokenEnvVar}')}"` // e.g. f"Bearer {os.getenv('HF_TOKEN')}"
464+
);
465+
snippet = snippet.replace(
466+
`"Key <ACCESS_TOKEN>"`,
467+
`f"Key {os.getenv('${accessTokenEnvVar}')}"` // e.g. f"Key {os.getenv('FAL_AI_API_TOKEN')}"
468+
);
469+
} else if (language === "js") {
470+
snippet = snippet.replace(
471+
`"${ACCESS_TOKEN_PLACEHOLDER}"`,
472+
`process.env.${accessTokenEnvVar}` // e.g. process.env.HF_TOKEN
473+
);
474+
}
475+
return snippet;
476+
}

packages/tasks-gen/scripts/generate-snippets-fixtures.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ function generateInferenceSnippet(
314314
): InferenceSnippet[] {
315315
const allSnippets = snippets.getInferenceSnippets(
316316
model,
317-
"api_token",
317+
"",
318318
provider,
319319
{
320320
hfModelId: model.id,

packages/tasks-gen/snippets-fixtures/automatic-speech-recognition/js/fetch/0.hf-inference.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ async function query(data) {
33
"https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3-turbo",
44
{
55
headers: {
6-
Authorization: "Bearer api_token",
6+
Authorization: "Bearer <ACCESS_TOKEN>",
77
"Content-Type": "audio/flac",
88
},
99
method: "POST",

packages/tasks-gen/snippets-fixtures/automatic-speech-recognition/js/huggingface.js/0.hf-inference.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { InferenceClient } from "@huggingface/inference";
22

3-
const client = new InferenceClient("api_token");
3+
const client = new InferenceClient(process.env.HF_TOKEN);
44

55
const data = fs.readFileSync("sample1.flac");
66

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import os
12
from huggingface_hub import InferenceClient
23

34
client = InferenceClient(
45
provider="hf-inference",
5-
api_key="api_token",
6+
api_key=os.getenv("HF_TOKEN"),
67
)
78

89
output = client.automatic_speech_recognition("sample1.flac", model="openai/whisper-large-v3-turbo")

packages/tasks-gen/snippets-fixtures/automatic-speech-recognition/python/requests/0.hf-inference.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import os
12
import requests
23

34
API_URL = "https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3-turbo"
45
headers = {
5-
"Authorization": "Bearer api_token",
6+
"Authorization": f"Bearer {os.getenv('HF_TOKEN')}",
67
}
78

89
def query(filename):
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
curl https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3-turbo \
22
-X POST \
3-
-H 'Authorization: Bearer api_token' \
3+
-H "Authorization: Bearer $HF_TOKEN" \
44
-H 'Content-Type: audio/flac' \
55
--data-binary @"sample1.flac"

packages/tasks-gen/snippets-fixtures/basic-snippet--token-classification/js/fetch/0.hf-inference.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ async function query(data) {
33
"https://router.huggingface.co/hf-inference/models/FacebookAI/xlm-roberta-large-finetuned-conll03-english",
44
{
55
headers: {
6-
Authorization: "Bearer api_token",
6+
Authorization: "Bearer <ACCESS_TOKEN>",
77
"Content-Type": "application/json",
88
},
99
method: "POST",

packages/tasks-gen/snippets-fixtures/basic-snippet--token-classification/js/huggingface.js/0.hf-inference.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { InferenceClient } from "@huggingface/inference";
22

3-
const client = new InferenceClient("api_token");
3+
const client = new InferenceClient(process.env.HF_TOKEN);
44

55
const output = await client.tokenClassification({
66
model: "FacebookAI/xlm-roberta-large-finetuned-conll03-english",

packages/tasks-gen/snippets-fixtures/basic-snippet--token-classification/python/huggingface_hub/0.hf-inference.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import os
12
from huggingface_hub import InferenceClient
23

34
client = InferenceClient(
45
provider="hf-inference",
5-
api_key="api_token",
6+
api_key=os.getenv("HF_TOKEN"),
67
)
78

89
result = client.token_classification(

0 commit comments

Comments (0)