
Commit aeb0731

curl snippets
1 parent a941c0b commit aeb0731

File tree: 3 files changed, 71 additions and 27 deletions

packages/tasks/src/inference-providers.ts
packages/tasks/src/snippets/curl.ts
packages/tasks/src/snippets/js.ts


packages/tasks/src/inference-providers.ts

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@ export const INFERENCE_PROVIDERS = ["hf-inference", "fal-ai", "replicate", "samb
 
 export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
 
-const HF_HUB_INFERENCE_PROXY_TEMPLATE = `https://huggingface.co/api/inference-proxy/{{PROVIDER}}`;
+export const HF_HUB_INFERENCE_PROXY_TEMPLATE = `https://huggingface.co/api/inference-proxy/{{PROVIDER}}`;
 
 /**
  * URL to set as baseUrl in the OpenAI SDK.
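The template is now exported so that curl.ts (next file) can substitute a provider name into it when building chat-completion URLs. A minimal sketch of that expansion, using "fal-ai" from the INFERENCE_PROVIDERS list in the hunk header above as the example provider:

import { HF_HUB_INFERENCE_PROXY_TEMPLATE } from "../inference-providers.js";

// Substitute a provider name into the now-exported proxy template,
// the same string replacement curl.ts performs below.
const base = HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", "fal-ai");
// base === "https://huggingface.co/api/inference-proxy/fal-ai"
const chatCompletionsUrl = `${base}/v1/chat/completions`;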

packages/tasks/src/snippets/curl.ts

Lines changed: 69 additions & 25 deletions
@@ -1,30 +1,47 @@
-import type { InferenceProvider } from "../inference-providers.js";
+import { HF_HUB_INFERENCE_PROXY_TEMPLATE, type InferenceProvider } from "../inference-providers.js";
 import type { PipelineType } from "../pipelines.js";
 import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
 import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
 import { getModelInputSnippet } from "./inputs.js";
 import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
 
-export const snippetBasic = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
-    content: `curl https://api-inference.huggingface.co/models/${model.id} \\
-    -X POST \\
-    -d '{"inputs": ${getModelInputSnippet(model, true)}}' \\
-    -H 'Content-Type: application/json' \\
-    -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
-});
+export const snippetBasic = (
+    model: ModelDataMinimal,
+    accessToken: string,
+    provider: InferenceProvider
+): InferenceSnippet[] => {
+    if (provider !== "hf-inference") {
+        return [];
+    }
+    return [
+        {
+            content: `curl https://api-inference.huggingface.co/models/${model.id} \\
+    -X POST \\
+    -d '{"inputs": ${getModelInputSnippet(model, true)}}' \\
+    -H 'Content-Type: application/json' \\
+    -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
+        },
+    ];
+};
 
 export const snippetTextGeneration = (
     model: ModelDataMinimal,
     accessToken: string,
+    provider: InferenceProvider,
     opts?: {
         streaming?: boolean;
         messages?: ChatCompletionInputMessage[];
         temperature?: GenerationParameters["temperature"];
         max_tokens?: GenerationParameters["max_tokens"];
         top_p?: GenerationParameters["top_p"];
     }
-): InferenceSnippet => {
+): InferenceSnippet[] => {
     if (model.tags.includes("conversational")) {
+        const baseUrl =
+            provider === "hf-inference"
+                ? `https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions`
+                : HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider) + "/v1/chat/completions";
+
         // Conversational model detected, so we display a code snippet that features the Messages API
         const streaming = opts?.streaming ?? true;
         const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
@@ -35,8 +52,9 @@ export const snippetTextGeneration = (
         max_tokens: opts?.max_tokens ?? 500,
         ...(opts?.top_p ? { top_p: opts.top_p } : undefined),
     };
-    return {
-        content: `curl 'https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions' \\
+    return [
+        {
+            content: `curl '${baseUrl}' \\
 -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}' \\
 -H 'Content-Type: application/json' \\
 --data '{
@@ -53,31 +71,59 @@ export const snippetTextGeneration = (
     })},
     "stream": ${!!streaming}
 }'`,
-    };
+        },
+    ];
     } else {
-        return snippetBasic(model, accessToken);
+        return snippetBasic(model, accessToken, provider);
     }
 };
 
-export const snippetZeroShotClassification = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
-    content: `curl https://api-inference.huggingface.co/models/${model.id} \\
+export const snippetZeroShotClassification = (
+    model: ModelDataMinimal,
+    accessToken: string,
+    provider: InferenceProvider
+): InferenceSnippet[] => {
+    if (provider !== "hf-inference") {
+        return [];
+    }
+    return [
+        {
+            content: `curl https://api-inference.huggingface.co/models/${model.id} \\
     -X POST \\
     -d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\
     -H 'Content-Type: application/json' \\
     -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
-});
+        },
+    ];
+};
 
-export const snippetFile = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
-    content: `curl https://api-inference.huggingface.co/models/${model.id} \\
+export const snippetFile = (
+    model: ModelDataMinimal,
+    accessToken: string,
+    provider: InferenceProvider
+): InferenceSnippet[] => {
+    if (provider !== "hf-inference") {
+        return [];
+    }
+    return [
+        {
+            content: `curl https://api-inference.huggingface.co/models/${model.id} \\
     -X POST \\
     --data-binary '@${getModelInputSnippet(model, true, true)}' \\
     -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
-});
+        },
+    ];
+};
 
 export const curlSnippets: Partial<
     Record<
         PipelineType,
-        (model: ModelDataMinimal, accessToken: string, opts?: Record<string, unknown>) => InferenceSnippet
+        (
+            model: ModelDataMinimal,
+            accessToken: string,
+            provider: InferenceProvider,
+            opts?: Record<string, unknown>
+        ) => InferenceSnippet[]
     >
 > = {
     // Same order as in tasks/src/pipelines.ts
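After this refactor, every snippet helper shares the same contract: it returns an array of snippets, and an empty array when the selected provider cannot serve raw curl calls. A minimal sketch of that behavior; the model object and token are placeholders, and ModelDataMinimal is assumed to need only the fields used here:

import { snippetBasic } from "./curl.js";
import type { ModelDataMinimal } from "./types.js";

// Placeholder model; fields beyond these are omitted for illustration.
const model = {
    id: "google-bert/bert-base-uncased",
    pipeline_tag: "fill-mask",
    tags: [],
} as unknown as ModelDataMinimal;

snippetBasic(model, "{API_TOKEN}", "hf-inference"); // -> one curl snippet
snippetBasic(model, "{API_TOKEN}", "fal-ai"); // -> [] (raw curl snippets only target hf-inference)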
@@ -112,11 +158,9 @@ export function getCurlInferenceSnippet(
     provider: InferenceProvider,
     opts?: Record<string, unknown>
 ): InferenceSnippet[] {
-    const snippets =
-        model.pipeline_tag && model.pipeline_tag in curlSnippets
-            ? curlSnippets[model.pipeline_tag]?.(model, accessToken, opts) ?? [{ content: "" }]
-            : [{ content: "" }];
-    return Array.isArray(snippets) ? snippets : [snippets];
+    return model.pipeline_tag && model.pipeline_tag in curlSnippets
+        ? curlSnippets[model.pipeline_tag]?.(model, accessToken, provider, opts) ?? []
+        : [];
 }
 
 export function hasCurlInferenceSnippet(model: Pick<ModelDataMinimal, "pipeline_tag">): boolean {
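A hypothetical end-to-end call of the simplified entry point, assuming a conversational text-generation model (the model id and token are illustrative placeholders; "replicate" comes from the INFERENCE_PROVIDERS list):

import { getCurlInferenceSnippet } from "./curl.js";
import type { ModelDataMinimal } from "./types.js";

// Placeholder conversational model; assumed minimal shape for illustration.
const chatModel = {
    id: "meta-llama/Llama-3.1-8B-Instruct",
    pipeline_tag: "text-generation",
    tags: ["conversational"],
} as unknown as ModelDataMinimal;

const snippets = getCurlInferenceSnippet(chatModel, "{API_TOKEN}", "replicate");
// snippets[0].content begins with:
//   curl 'https://huggingface.co/api/inference-proxy/replicate/v1/chat/completions' ...
// A pipeline_tag with no registered handler now yields [] instead of [{ content: "" }].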

packages/tasks/src/snippets/js.ts

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-import { HF_HUB_INFERENCE_PROXY_TEMPLATE, openAIbaseUrl, type InferenceProvider } from "../inference-providers.js";
+import { openAIbaseUrl, type InferenceProvider } from "../inference-providers.js";
 import type { PipelineType } from "../pipelines.js";
 import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
 import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
