
Commit 8aafe25

openai client compat
1 parent 7184d1b commit 8aafe25

4 files changed: +40 −35 lines

packages/tasks/src/inference-providers.ts

Lines changed: 2 additions & 0 deletions

@@ -1,3 +1,5 @@
 export const INFERENCE_PROVIDERS = ["hf-inference", "fal-ai", "replicate", "sambanova", "together"] as const;
 
 export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
+
+export const HF_HUB_INFERENCE_PROXY_TEMPLATE = `https://huggingface.co/api/inference-proxy/{{PROVIDER}}`;
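
The template is resolved with a plain string replace on the {{PROVIDER}} placeholder, exactly as done in js.ts below. A minimal sketch of that resolution (the provider value and import path are illustrative):

import { HF_HUB_INFERENCE_PROXY_TEMPLATE, type InferenceProvider } from "./inference-providers.js";

// Any provider other than "hf-inference" is routed through the Hub proxy.
const provider: InferenceProvider = "together";
const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);
// baseUrl === "https://huggingface.co/api/inference-proxy/together"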

packages/tasks/src/snippets/curl.ts

Lines changed: 9 additions & 5 deletions

@@ -1,3 +1,4 @@
+import type { InferenceProvider } from "../inference-providers.js";
 import type { PipelineType } from "../pipelines.js";
 import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
 import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
@@ -79,7 +80,7 @@ export const curlSnippets: Partial<
         (model: ModelDataMinimal, accessToken: string, opts?: Record<string, unknown>) => InferenceSnippet
     >
 > = {
-    // Same order as in js/src/lib/interfaces/Types.ts
+    // Same order as in tasks/src/pipelines.ts
     "text-classification": snippetBasic,
     "token-classification": snippetBasic,
    "table-question-answering": snippetBasic,
@@ -108,11 +109,14 @@ export const curlSnippets: Partial<
 export function getCurlInferenceSnippet(
     model: ModelDataMinimal,
     accessToken: string,
+    provider: InferenceProvider,
     opts?: Record<string, unknown>
-): InferenceSnippet {
-    return model.pipeline_tag && model.pipeline_tag in curlSnippets
-        ? curlSnippets[model.pipeline_tag]?.(model, accessToken, opts) ?? { content: "" }
-        : { content: "" };
+): InferenceSnippet[] {
+    const snippets =
+        model.pipeline_tag && model.pipeline_tag in curlSnippets
+            ? curlSnippets[model.pipeline_tag]?.(model, accessToken, opts) ?? [{ content: "" }]
+            : [{ content: "" }];
+    return Array.isArray(snippets) ? snippets : [snippets];
 }
 
 export function hasCurlInferenceSnippet(model: Pick<ModelDataMinimal, "pipeline_tag">): boolean {
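
With the new signature, callers pass a provider and always get an InferenceSnippet[] back, so they can iterate without special-casing a single snippet. A hedged usage sketch (the model literal and import paths are hypothetical; real callers pass ModelDataMinimal data fetched from the Hub):

import { getCurlInferenceSnippet } from "./curl.js";
import type { ModelDataMinimal } from "./types.js"; // assumed location of ModelDataMinimal

// Hypothetical minimal model data, for illustration only.
const model = {
    id: "my-org/my-model",
    pipeline_tag: "text-classification",
    tags: [],
} as ModelDataMinimal;

for (const snippet of getCurlInferenceSnippet(model, "hf_***", "hf-inference")) {
    console.log(snippet.content);
}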

packages/tasks/src/snippets/js.ts

Lines changed: 24 additions & 27 deletions

@@ -1,4 +1,4 @@
-import type { InferenceProvider } from "../inference-providers.js";
+import { HF_HUB_INFERENCE_PROXY_TEMPLATE, type InferenceProvider } from "../inference-providers.js";
 import type { PipelineType } from "../pipelines.js";
 import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
 import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
@@ -51,6 +51,11 @@ export const snippetTextGeneration = (
         top_p?: GenerationParameters["top_p"];
     }
 ): InferenceSnippet[] => {
+    const openAIbaseUrl =
+        provider === "hf-inference"
+            ? "https://api-inference.huggingface.co/v1/"
+            : HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);
+
     if (model.tags.includes("conversational")) {
         // Conversational model detected, so we display a code snippet that features the Messages API
         const streaming = opts?.streaming ?? true;
@@ -93,15 +98,13 @@ for await (const chunk of stream) {
     }
 }`,
         },
-        ...(provider === "hf-inference"
-            ? [
-                {
-                    client: "openai",
-                    content: `import { OpenAI } from "openai";
+        {
+            client: "openai",
+            content: `import { OpenAI } from "openai";
 
 const client = new OpenAI({
-    baseURL: "https://api-inference.huggingface.co/v1/",
-    apiKey: "${accessToken || `{API_TOKEN}`}"
+    baseURL: "${openAIbaseUrl}",
+    apiKey: "${accessToken || `{API_TOKEN}`}"
 });
 
 let out = "";
@@ -120,9 +123,7 @@ for await (const chunk of stream) {
         console.log(newContent);
     }
 }`,
-                },
-            ]
-            : []),
+        },
     ];
 } else {
     return [
@@ -141,15 +142,13 @@ const chatCompletion = await client.chatCompletion({
 
 console.log(chatCompletion.choices[0].message);`,
         },
-        ...(provider === "hf-inference"
-            ? [
-                {
-                    client: "openai",
-                    content: `import { OpenAI } from "openai";
+        {
+            client: "openai",
+            content: `import { OpenAI } from "openai";
 
 const client = new OpenAI({
-    baseURL: "https://api-inference.huggingface.co/v1/",
-    apiKey: "${accessToken || `{API_TOKEN}`}"
+    baseURL: "${openAIbaseUrl}",
+    apiKey: "${accessToken || `{API_TOKEN}`}"
 });
 
 const chatCompletion = await client.chat.completions.create({
@@ -159,9 +158,7 @@ const chatCompletion = await client.chat.completions.create({
 });
 
 console.log(chatCompletion.choices[0].message);`,
-                },
-            ]
-            : []),
+        },
     ];
 }
 } else {
@@ -227,9 +224,9 @@ infer(${getModelInputSnippet(model)}, { num_inference_steps: 5 }).then((image) =
         },
         ...(provider === "hf-inference"
             ? [
-                    {
-                        client: "fetch",
-                        content: `async function query(data) {
+                {
+                    client: "fetch",
+                    content: `async function query(data) {
     const response = await fetch(
         "https://api-inference.huggingface.co/models/${model.id}",
         {
@@ -247,8 +244,8 @@ infer(${getModelInputSnippet(model)}, { num_inference_steps: 5 }).then((image) =
 query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
     // Use image
 });`,
-                    },
-                ]
+                },
+            ]
             : []),
     ];
 };
@@ -385,7 +382,7 @@ export const jsSnippets: Partial<
     ) => InferenceSnippet[]
 >
 > = {
-    // Same order as in src/pipelines.ts
+    // Same order as in tasks/src/pipelines.ts
     "text-classification": snippetBasic,
     "token-classification": snippetBasic,
    "table-question-answering": snippetBasic,

packages/tasks/src/snippets/python.ts

Lines changed: 5 additions & 3 deletions

@@ -1,3 +1,4 @@
+import type { InferenceProvider } from "../inference-providers.js";
 import type { PipelineType } from "../pipelines.js";
 import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
 import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
@@ -287,16 +288,17 @@ export const pythonSnippets: Partial<
 export function getPythonInferenceSnippet(
     model: ModelDataMinimal,
     accessToken: string,
+    provider: InferenceProvider,
     opts?: Record<string, unknown>
-): InferenceSnippet | InferenceSnippet[] {
+): InferenceSnippet[] {
     if (model.tags.includes("conversational")) {
         // Conversational model detected, so we display a code snippet that features the Messages API
         return snippetConversational(model, accessToken, opts);
     } else {
         let snippets =
             model.pipeline_tag && model.pipeline_tag in pythonSnippets
-                ? pythonSnippets[model.pipeline_tag]?.(model, accessToken) ?? { content: "" }
-                : { content: "" };
+                ? pythonSnippets[model.pipeline_tag]?.(model, accessToken) ?? [{ content: "" }]
+                : [{ content: "" }];
 
         snippets = Array.isArray(snippets) ? snippets : [snippets];
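
getPythonInferenceSnippet mirrors the curl change: provider is threaded through and the return type is normalized to InferenceSnippet[] (the hunk shown ends before the parameter is used). A consumer sketch under the same hypothetical-model caveat as the curl example above:

import { getPythonInferenceSnippet } from "./python.js";
import type { ModelDataMinimal } from "./types.js"; // assumed location, as above

// Hypothetical non-conversational model, assuming its task has a Python snippet.
const model = {
    id: "my-org/my-classifier",
    pipeline_tag: "text-classification",
    tags: [],
} as ModelDataMinimal;

const [first] = getPythonInferenceSnippet(model, "hf_***", "hf-inference");
console.log(first.content);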
