Skip to content

Commit d515d60

Browse files
julien-c and SBrandeis authored
[tasks] in inference snippets, provider does not need to be strongly typed (#1184)
will make things easier when adding new providers (one less moving part) --------- Co-authored-by: SBrandeis <[email protected]>
1 parent b6b7fc4 commit d515d60

File tree

6 files changed

+82
-30
lines changed

6 files changed

+82
-30
lines changed

packages/inference/test/tapes.json

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6789,5 +6789,45 @@
67896789
"server": "UploadServer"
67906790
}
67916791
}
6792+
},
6793+
"efa2b5ab7171e43629fef33886a32583919f4dfe814ae07a44db19257ee123ae": {
6794+
"url": "https://fal.run/fal-ai/fast-sdxl",
6795+
"init": {
6796+
"headers": {
6797+
"Content-Type": "application/json"
6798+
},
6799+
"method": "POST",
6800+
"body": "{\"response_format\":\"base64\",\"prompt\":\"Extreme close-up of a single tiger eye, direct frontal view. Detailed iris and pupil. Sharp focus on eye texture and color. Natural lighting to capture authentic eye shine and depth.\"}"
6801+
},
6802+
"response": {
6803+
"body": "{\"images\":[{\"url\":\"https://fal.media/files/monkey/t28MYvYK21vq9nIypBm0P.jpeg\",\"width\":1024,\"height\":1024,\"content_type\":\"image/jpeg\"}],\"timings\":{\"inference\":2.1236871778964996},\"seed\":15619174981588513000,\"has_nsfw_concepts\":[false],\"prompt\":\"Extreme close-up of a single tiger eye, direct frontal view. Detailed iris and pupil. Sharp focus on eye texture and color. Natural lighting to capture authentic eye shine and depth.\"}",
6804+
"status": 200,
6805+
"statusText": "OK",
6806+
"headers": {
6807+
"connection": "keep-alive",
6808+
"content-type": "application/json",
6809+
"strict-transport-security": "max-age=31536000; includeSubDomains"
6810+
}
6811+
}
6812+
},
6813+
"374890ec5b45788656310c21999957168f47242bd379c91da86d00eab7b9b218": {
6814+
"url": "https://fal.media/files/monkey/t28MYvYK21vq9nIypBm0P.jpeg",
6815+
"init": {},
6816+
"response": {
6817+
"body": "",
6818+
"status": 200,
6819+
"statusText": "OK",
6820+
"headers": {
6821+
"access-control-allow-headers": "*",
6822+
"access-control-allow-methods": "*",
6823+
"access-control-allow-origin": "*",
6824+
"access-control-max-age": "86400",
6825+
"cf-ray": "90d404087de9999f-CDG",
6826+
"connection": "keep-alive",
6827+
"content-type": "image/jpeg",
6828+
"server": "cloudflare",
6829+
"vary": "Accept-Encoding"
6830+
}
6831+
}
67926832
}
67936833
}

packages/inference/test/vcr.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ async function vcr(
117117

118118
const { default: tapes } = await import(TAPES_FILE);
119119

120-
const cacheCandidate = !url.startsWith(HF_HUB_URL) || url.startsWith("https://huggingface.co/api/inference-proxy/");
120+
const cacheCandidate = !url.startsWith(HF_HUB_URL);
121121

122122
if (VCR_MODE === MODE.PLAYBACK && cacheCandidate) {
123123
if (!tapes[hash]) {
Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,24 @@
1-
export const INFERENCE_PROVIDERS = ["hf-inference", "fal-ai", "replicate", "sambanova", "together"] as const;
1+
/// This list is for illustration purposes only.
2+
/// in the `tasks` sub-package, we do not need actual strong typing of the inference providers.
3+
const INFERENCE_PROVIDERS = [
4+
"fal-ai",
5+
"fireworks-ai",
6+
"hf-inference",
7+
"hyperbolic",
8+
"replicate",
9+
"sambanova",
10+
"together",
11+
] as const;
212

3-
export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
13+
export type SnippetInferenceProvider = (typeof INFERENCE_PROVIDERS)[number] | string;
414

5-
export const HF_HUB_INFERENCE_PROXY_TEMPLATE = `https://huggingface.co/api/inference-proxy/{{PROVIDER}}`;
15+
export const HF_HUB_INFERENCE_PROXY_TEMPLATE = `https://router.huggingface.co/{{PROVIDER}}`;
616

717
/**
818
* URL to set as baseUrl in the OpenAI SDK.
919
*
1020
* TODO(Expose this from HfInference in the future?)
1121
*/
12-
export function openAIbaseUrl(provider: InferenceProvider): string {
13-
return provider === "hf-inference"
14-
? "https://api-inference.huggingface.co/v1/"
15-
: HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);
22+
export function openAIbaseUrl(provider: SnippetInferenceProvider): string {
23+
return HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);
1624
}

packages/tasks/src/snippets/curl.ts

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { HF_HUB_INFERENCE_PROXY_TEMPLATE, type InferenceProvider } from "../inference-providers.js";
1+
import { HF_HUB_INFERENCE_PROXY_TEMPLATE, type SnippetInferenceProvider } from "../inference-providers.js";
22
import type { PipelineType } from "../pipelines.js";
33
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
44
import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
@@ -8,7 +8,7 @@ import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
88
export const snippetBasic = (
99
model: ModelDataMinimal,
1010
accessToken: string,
11-
provider: InferenceProvider
11+
provider: SnippetInferenceProvider
1212
): InferenceSnippet[] => {
1313
if (provider !== "hf-inference") {
1414
return [];
@@ -29,7 +29,7 @@ curl https://api-inference.huggingface.co/models/${model.id} \\
2929
export const snippetTextGeneration = (
3030
model: ModelDataMinimal,
3131
accessToken: string,
32-
provider: InferenceProvider,
32+
provider: SnippetInferenceProvider,
3333
opts?: {
3434
streaming?: boolean;
3535
messages?: ChatCompletionInputMessage[];
@@ -84,7 +84,7 @@ export const snippetTextGeneration = (
8484
export const snippetZeroShotClassification = (
8585
model: ModelDataMinimal,
8686
accessToken: string,
87-
provider: InferenceProvider
87+
provider: SnippetInferenceProvider
8888
): InferenceSnippet[] => {
8989
if (provider !== "hf-inference") {
9090
return [];
@@ -104,7 +104,7 @@ export const snippetZeroShotClassification = (
104104
export const snippetFile = (
105105
model: ModelDataMinimal,
106106
accessToken: string,
107-
provider: InferenceProvider
107+
provider: SnippetInferenceProvider
108108
): InferenceSnippet[] => {
109109
if (provider !== "hf-inference") {
110110
return [];
@@ -126,7 +126,7 @@ export const curlSnippets: Partial<
126126
(
127127
model: ModelDataMinimal,
128128
accessToken: string,
129-
provider: InferenceProvider,
129+
provider: SnippetInferenceProvider,
130130
opts?: Record<string, unknown>
131131
) => InferenceSnippet[]
132132
>
@@ -160,7 +160,7 @@ export const curlSnippets: Partial<
160160
export function getCurlInferenceSnippet(
161161
model: ModelDataMinimal,
162162
accessToken: string,
163-
provider: InferenceProvider,
163+
provider: SnippetInferenceProvider,
164164
opts?: Record<string, unknown>
165165
): InferenceSnippet[] {
166166
return model.pipeline_tag && model.pipeline_tag in curlSnippets

packages/tasks/src/snippets/js.ts

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { openAIbaseUrl, type InferenceProvider } from "../inference-providers.js";
1+
import { openAIbaseUrl, type SnippetInferenceProvider } from "../inference-providers.js";
22
import type { PipelineType } from "../pipelines.js";
33
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
44
import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
@@ -22,7 +22,7 @@ const HFJS_METHODS: Record<string, string> = {
2222
export const snippetBasic = (
2323
model: ModelDataMinimal,
2424
accessToken: string,
25-
provider: InferenceProvider
25+
provider: SnippetInferenceProvider
2626
): InferenceSnippet[] => {
2727
return [
2828
...(model.pipeline_tag && model.pipeline_tag in HFJS_METHODS
@@ -74,7 +74,7 @@ query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
7474
export const snippetTextGeneration = (
7575
model: ModelDataMinimal,
7676
accessToken: string,
77-
provider: InferenceProvider,
77+
provider: SnippetInferenceProvider,
7878
opts?: {
7979
streaming?: boolean;
8080
messages?: ChatCompletionInputMessage[];
@@ -225,7 +225,7 @@ export const snippetZeroShotClassification = (model: ModelDataMinimal, accessTok
225225
export const snippetTextToImage = (
226226
model: ModelDataMinimal,
227227
accessToken: string,
228-
provider: InferenceProvider
228+
provider: SnippetInferenceProvider
229229
): InferenceSnippet[] => {
230230
return [
231231
{
@@ -275,7 +275,7 @@ query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
275275
export const snippetTextToAudio = (
276276
model: ModelDataMinimal,
277277
accessToken: string,
278-
provider: InferenceProvider
278+
provider: SnippetInferenceProvider
279279
): InferenceSnippet[] => {
280280
if (provider !== "hf-inference") {
281281
return [];
@@ -329,7 +329,7 @@ export const snippetTextToAudio = (
329329
export const snippetAutomaticSpeechRecognition = (
330330
model: ModelDataMinimal,
331331
accessToken: string,
332-
provider: InferenceProvider
332+
provider: SnippetInferenceProvider
333333
): InferenceSnippet[] => {
334334
return [
335335
{
@@ -357,7 +357,7 @@ console.log(output);
357357
export const snippetFile = (
358358
model: ModelDataMinimal,
359359
accessToken: string,
360-
provider: InferenceProvider
360+
provider: SnippetInferenceProvider
361361
): InferenceSnippet[] => {
362362
if (provider !== "hf-inference") {
363363
return [];
@@ -395,7 +395,7 @@ export const jsSnippets: Partial<
395395
(
396396
model: ModelDataMinimal,
397397
accessToken: string,
398-
provider: InferenceProvider,
398+
provider: SnippetInferenceProvider,
399399
opts?: Record<string, unknown>
400400
) => InferenceSnippet[]
401401
>
@@ -429,7 +429,7 @@ export const jsSnippets: Partial<
429429
export function getJsInferenceSnippet(
430430
model: ModelDataMinimal,
431431
accessToken: string,
432-
provider: InferenceProvider,
432+
provider: SnippetInferenceProvider,
433433
opts?: Record<string, unknown>
434434
): InferenceSnippet[] {
435435
return model.pipeline_tag && model.pipeline_tag in jsSnippets

packages/tasks/src/snippets/python.ts

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1-
import { HF_HUB_INFERENCE_PROXY_TEMPLATE, openAIbaseUrl, type InferenceProvider } from "../inference-providers.js";
1+
import {
2+
HF_HUB_INFERENCE_PROXY_TEMPLATE,
3+
openAIbaseUrl,
4+
type SnippetInferenceProvider,
5+
} from "../inference-providers.js";
26
import type { PipelineType } from "../pipelines.js";
37
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
48
import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
59
import { getModelInputSnippet } from "./inputs.js";
610
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
711

8-
const snippetImportInferenceClient = (accessToken: string, provider: InferenceProvider): string =>
12+
const snippetImportInferenceClient = (accessToken: string, provider: SnippetInferenceProvider): string =>
913
`\
1014
from huggingface_hub import InferenceClient
1115
@@ -17,7 +21,7 @@ client = InferenceClient(
1721
export const snippetConversational = (
1822
model: ModelDataMinimal,
1923
accessToken: string,
20-
provider: InferenceProvider,
24+
provider: SnippetInferenceProvider,
2125
opts?: {
2226
streaming?: boolean;
2327
messages?: ChatCompletionInputMessage[];
@@ -199,7 +203,7 @@ output = query(${getModelInputSnippet(model)})`,
199203
export const snippetTextToImage = (
200204
model: ModelDataMinimal,
201205
accessToken: string,
202-
provider: InferenceProvider
206+
provider: SnippetInferenceProvider
203207
): InferenceSnippet[] => {
204208
return [
205209
{
@@ -337,7 +341,7 @@ export const pythonSnippets: Partial<
337341
(
338342
model: ModelDataMinimal,
339343
accessToken: string,
340-
provider: InferenceProvider,
344+
provider: SnippetInferenceProvider,
341345
opts?: Record<string, unknown>
342346
) => InferenceSnippet[]
343347
>
@@ -375,7 +379,7 @@ export const pythonSnippets: Partial<
375379
export function getPythonInferenceSnippet(
376380
model: ModelDataMinimal,
377381
accessToken: string,
378-
provider: InferenceProvider,
382+
provider: SnippetInferenceProvider,
379383
opts?: Record<string, unknown>
380384
): InferenceSnippet[] {
381385
if (model.tags.includes("conversational")) {

0 commit comments

Comments
 (0)