-import type { InferenceProvider } from "../inference-providers.js";
+import { HF_HUB_INFERENCE_PROXY_TEMPLATE, type InferenceProvider } from "../inference-providers.js";
 import type { PipelineType } from "../pipelines.js";
 import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
 import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
@@ -51,6 +51,11 @@ export const snippetTextGeneration = (
 		top_p?: GenerationParameters["top_p"];
 	}
 ): InferenceSnippet[] => {
+	const openAIbaseUrl =
+		provider === "hf-inference"
+			? "https://api-inference.huggingface.co/v1/"
+			: HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);
+
 	if (model.tags.includes("conversational")) {
 		// Conversational model detected, so we display a code snippet that features the Messages API
 		const streaming = opts?.streaming ?? true;
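For reference, a minimal sketch of how the new openAIbaseUrl resolves per provider. The template value and the helper name below are assumptions for illustration (the real constant is exported from ../inference-providers.js); only the hf-inference special case and the {{PROVIDER}} substitution are taken from the diff:

// Sketch only: HF_HUB_INFERENCE_PROXY_TEMPLATE's value is assumed here;
// the real constant lives in ../inference-providers.js.
const HF_HUB_INFERENCE_PROXY_TEMPLATE = "https://huggingface.co/api/inference-proxy/{{PROVIDER}}";

// Hypothetical helper mirroring the ternary added in the hunk above.
function resolveOpenAIBaseUrl(provider: string): string {
	// hf-inference keeps the legacy serverless endpoint; any other
	// provider is routed through the Hub's inference proxy.
	return provider === "hf-inference"
		? "https://api-inference.huggingface.co/v1/"
		: HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);
}

resolveOpenAIBaseUrl("hf-inference"); // "https://api-inference.huggingface.co/v1/"
resolveOpenAIBaseUrl("together");     // "https://huggingface.co/api/inference-proxy/together"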
@@ -93,15 +98,13 @@ for await (const chunk of stream) {
 	}
 }`,
 			},
-			...(provider === "hf-inference"
-				? [
-						{
-							client: "openai",
-							content: `import { OpenAI } from "openai";
+			{
+				client: "openai",
+				content: `import { OpenAI } from "openai";
 
 const client = new OpenAI({
-	baseURL: "https://api-inference.huggingface.co/v1/",
-	apiKey: "${accessToken || `{API_TOKEN}`}"
+	baseURL: "${openAIbaseUrl}",
+	apiKey: "${accessToken || `{API_TOKEN}`}"
 });
 
 let out = "";
@@ -120,9 +123,7 @@ for await (const chunk of stream) {
 		console.log(newContent);
 	}
 }`,
-			},
-				]
-			: []),
+			},
 		];
 	} else {
 		return [
@@ -141,15 +142,13 @@ const chatCompletion = await client.chatCompletion({
 
 console.log(chatCompletion.choices[0].message);`,
 			},
-			...(provider === "hf-inference"
-				? [
-						{
-							client: "openai",
-							content: `import { OpenAI } from "openai";
+			{
+				client: "openai",
+				content: `import { OpenAI } from "openai";
 
 const client = new OpenAI({
-	baseURL: "https://api-inference.huggingface.co/v1/",
-	apiKey: "${accessToken || `{API_TOKEN}`}"
+	baseURL: "${openAIbaseUrl}",
+	apiKey: "${accessToken || `{API_TOKEN}`}"
 });
 
 const chatCompletion = await client.chat.completions.create({
@@ -159,9 +158,7 @@ const chatCompletion = await client.chat.completions.create({
 });
 
 console.log(chatCompletion.choices[0].message);`,
-			},
-				]
-			: []),
+			},
 		];
 	}
 } else {
@@ -227,9 +224,9 @@ infer(${getModelInputSnippet(model)}, { num_inference_steps: 5 }).then((image) =
 		},
 		...(provider === "hf-inference"
 			? [
-				{
-					client: "fetch",
-					content: `async function query(data) {
+					{
+						client: "fetch",
+						content: `async function query(data) {
 	const response = await fetch(
 		"https://api-inference.huggingface.co/models/${model.id}",
 		{
@@ -247,8 +244,8 @@ infer(${getModelInputSnippet(model)}, { num_inference_steps: 5 }).then((image) =
 query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
 	// Use image
 });`,
-				},
-			]
+					},
+				]
 			: []),
 	];
};
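Net effect of the hunks above: for conversational models the openai client snippet is now emitted for every provider, not just hf-inference, with the provider-aware base URL spliced in. A sketch of what the non-streaming branch might now generate for a third-party provider (provider name, model id, generation parameters, and the proxy URL are illustrative assumptions, not output from the diff):

// Illustrative generated snippet; all concrete values are placeholders.
import { OpenAI } from "openai";

const client = new OpenAI({
	baseURL: "https://huggingface.co/api/inference-proxy/together",
	apiKey: "{API_TOKEN}",
});

const chatCompletion = await client.chat.completions.create({
	model: "meta-llama/Llama-3.1-8B-Instruct",
	messages: [{ role: "user", content: "What is the capital of France?" }],
	max_tokens: 500,
});

console.log(chatCompletion.choices[0].message);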
@@ -385,7 +382,7 @@ export const jsSnippets: Partial<
 		) => InferenceSnippet[]
 	>
 > = {
-	// Same order as in src/pipelines.ts
+	// Same order as in tasks/src/pipelines.ts
 	"text-classification": snippetBasic,
 	"token-classification": snippetBasic,
 	"table-question-answering": snippetBasic,