Skip to content

Commit 62137bd

Browse files
authored
[InferenceSnippets] Add endpointUrl option (#1521)
Add an option to pass a custom endpoint URL for inference snippets generation. ```ts snippets.getInferenceSnippets( model, selectedProvider, { hfModelId: ..., providerId: ..., status: "live", task: pipeline, }, { streaming, endpointUrl: "http://localhost:8080/v1", } ) ``` cc @gary149
1 parent a444bd0 commit 62137bd

File tree

18 files changed

+176
-12
lines changed

18 files changed

+176
-12
lines changed

packages/inference/README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -651,9 +651,10 @@ You can use any Chat Completion API-compatible provider with the `chatCompletion
651651
```typescript
652652
// Chat Completion Example
653653
const MISTRAL_KEY = process.env.MISTRAL_KEY;
654-
const hf = new InferenceClient(MISTRAL_KEY);
655-
const ep = hf.endpoint("https://api.mistral.ai");
656-
const stream = ep.chatCompletionStream({
654+
const hf = new InferenceClient(MISTRAL_KEY, {
655+
endpointUrl: "https://api.mistral.ai",
656+
});
657+
const stream = hf.chatCompletionStream({
657658
model: "mistral-tiny",
658659
messages: [{ role: "user", content: "Complete the equation one + one = , just the answer" }],
659660
});

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ export type InferenceSnippetOptions = {
1818
streaming?: boolean;
1919
billTo?: string;
2020
accessToken?: string;
21-
directRequest?: boolean;
21+
directRequest?: boolean; // to bypass HF routing and call the provider directly
22+
endpointUrl?: string; // to call a local endpoint directly
2223
} & Record<string, unknown>;
2324

2425
const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const;
@@ -53,6 +54,7 @@ interface TemplateParams {
5354
methodName?: string; // specific to snippetBasic
5455
importBase64?: boolean; // specific to snippetImportRequests
5556
importJson?: boolean; // specific to snippetImportRequests
57+
endpointUrl?: string;
5658
}
5759

5860
// Helpers to find + load templates
@@ -172,6 +174,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
172174
{
173175
accessToken: accessTokenOrPlaceholder,
174176
provider,
177+
endpointUrl: opts?.endpointUrl,
175178
...inputs,
176179
} as RequestArgs,
177180
inferenceProviderMapping,
@@ -217,6 +220,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
217220
provider,
218221
providerModelId: providerModelId ?? model.id,
219222
billTo: opts?.billTo,
223+
endpointUrl: opts?.endpointUrl,
220224
};
221225

222226
/// Iterate over clients => check if a snippet exists => generate
@@ -265,7 +269,14 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
265269

266270
/// Replace access token placeholder
267271
if (snippet.includes(placeholder)) {
268-
snippet = replaceAccessTokenPlaceholder(opts?.directRequest, placeholder, snippet, language, provider);
272+
snippet = replaceAccessTokenPlaceholder(
273+
opts?.directRequest,
274+
placeholder,
275+
snippet,
276+
language,
277+
provider,
278+
opts?.endpointUrl
279+
);
269280
}
270281

271282
/// Snippet is ready!
@@ -444,21 +455,24 @@ function replaceAccessTokenPlaceholder(
444455
placeholder: string,
445456
snippet: string,
446457
language: InferenceSnippetLanguage,
447-
provider: InferenceProviderOrPolicy
458+
provider: InferenceProviderOrPolicy,
459+
endpointUrl?: string
448460
): string {
449461
// If "opts.accessToken" is not set, the snippets are generated with a placeholder.
450462
// Once snippets are rendered, we replace the placeholder with code to fetch the access token from an environment variable.
451463

452464
// Determine if HF_TOKEN or specific provider token should be used
453465
const useHfToken =
454-
provider == "hf-inference" || // hf-inference provider => use $HF_TOKEN
455-
(!directRequest && // if explicit directRequest => use provider-specific token
456-
(!snippet.includes("https://") || // no URL provided => using a client => use $HF_TOKEN
457-
snippet.includes("https://router.huggingface.co"))); // explicit routed request => use $HF_TOKEN
458-
466+
!endpointUrl && // custom endpointUrl => use a generic API_TOKEN
467+
(provider == "hf-inference" || // hf-inference provider => use $HF_TOKEN
468+
(!directRequest && // if explicit directRequest => use provider-specific token
469+
(!snippet.includes("https://") || // no URL provided => using a client => use $HF_TOKEN
470+
snippet.includes("https://router.huggingface.co")))); // explicit routed request => use $HF_TOKEN
459471
const accessTokenEnvVar = useHfToken
460472
? "HF_TOKEN" // e.g. routed request or hf-inference
461-
: provider.toUpperCase().replace("-", "_") + "_API_KEY"; // e.g. "REPLICATE_API_KEY"
473+
: endpointUrl
474+
? "API_TOKEN"
475+
: provider.toUpperCase().replace("-", "_") + "_API_KEY"; // e.g. "REPLICATE_API_KEY"
462476

463477
// Replace the placeholder with the env variable
464478
if (language === "sh") {

packages/inference/src/snippets/templates/js/huggingface.js/basic.jinja

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
33
const client = new InferenceClient("{{ accessToken }}");
44

55
const output = await client.{{ methodName }}({
6+
{% if endpointUrl %}
7+
endpointUrl: "{{ endpointUrl }}",
8+
{% endif %}
69
model: "{{ model.id }}",
710
inputs: {{ inputs.asObj.inputs }},
811
provider: "{{ provider }}",

packages/inference/src/snippets/templates/js/huggingface.js/basicAudio.jinja

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ const client = new InferenceClient("{{ accessToken }}");
55
const data = fs.readFileSync({{inputs.asObj.inputs}});
66

77
const output = await client.{{ methodName }}({
8+
{% if endpointUrl %}
9+
endpointUrl: "{{ endpointUrl }}",
10+
{% endif %}
811
data,
912
model: "{{ model.id }}",
1013
provider: "{{ provider }}",

packages/inference/src/snippets/templates/js/huggingface.js/basicImage.jinja

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ const client = new InferenceClient("{{ accessToken }}");
55
const data = fs.readFileSync({{inputs.asObj.inputs}});
66

77
const output = await client.{{ methodName }}({
8+
{% if endpointUrl %}
9+
endpointUrl: "{{ endpointUrl }}",
10+
{% endif %}
811
data,
912
model: "{{ model.id }}",
1013
provider: "{{ provider }}",

packages/inference/src/snippets/templates/js/huggingface.js/conversational.jinja

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
33
const client = new InferenceClient("{{ accessToken }}");
44

55
const chatCompletion = await client.chatCompletion({
6+
{% if endpointUrl %}
7+
endpointUrl: "{{ endpointUrl }}",
8+
{% endif %}
69
provider: "{{ provider }}",
710
model: "{{ model.id }}",
811
{{ inputs.asTsString }}

packages/inference/src/snippets/templates/js/huggingface.js/conversationalStream.jinja

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ const client = new InferenceClient("{{ accessToken }}");
55
let out = "";
66

77
const stream = client.chatCompletionStream({
8+
{% if endpointUrl %}
9+
endpointUrl: "{{ endpointUrl }}",
10+
{% endif %}
811
provider: "{{ provider }}",
912
model: "{{ model.id }}",
1013
{{ inputs.asTsString }}

packages/inference/src/snippets/templates/js/huggingface.js/textToImage.jinja

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
33
const client = new InferenceClient("{{ accessToken }}");
44

55
const image = await client.textToImage({
6+
{% if endpointUrl %}
7+
endpointUrl: "{{ endpointUrl }}",
8+
{% endif %}
69
provider: "{{ provider }}",
710
model: "{{ model.id }}",
811
inputs: {{ inputs.asObj.inputs }},

packages/inference/src/snippets/templates/js/huggingface.js/textToSpeech.jinja

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
33
const client = new InferenceClient("{{ accessToken }}");
44

55
const audio = await client.textToSpeech({
6+
{% if endpointUrl %}
7+
endpointUrl: "{{ endpointUrl }}",
8+
{% endif %}
69
provider: "{{ provider }}",
710
model: "{{ model.id }}",
811
inputs: {{ inputs.asObj.inputs }},

packages/inference/src/snippets/templates/js/huggingface.js/textToVideo.jinja

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ import { InferenceClient } from "@huggingface/inference";
33
const client = new InferenceClient("{{ accessToken }}");
44

55
const video = await client.textToVideo({
6+
{% if endpointUrl %}
7+
endpointUrl: "{{ endpointUrl }}",
8+
{% endif %}
69
provider: "{{ provider }}",
710
model: "{{ model.id }}",
811
inputs: {{ inputs.asObj.inputs }},

0 commit comments

Comments
 (0)