13 changes: 12 additions & 1 deletion packages/agents/pnpm-lock.yaml

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion packages/inference/README.md
@@ -1,4 +1,4 @@
-# 🤗 Hugging Face Inference Endpoints
+# 🤗 Hugging Face Inference
 
 A Typescript powered wrapper for the Hugging Face Inference API (serverless), Inference Endpoints (dedicated), and third-party Inference Providers.
 It works with [Inference API (serverless)](https://huggingface.co/docs/api-inference/index) and [Inference Endpoints (dedicated)](https://huggingface.co/docs/inference-endpoints/index), and even with supported third-party Inference Providers.
2 changes: 2 additions & 0 deletions packages/tasks/src/index.ts
@@ -58,3 +58,5 @@ export type { LocalApp, LocalAppKey, LocalAppSnippet } from "./local-apps.js";
 
 export { DATASET_LIBRARIES_UI_ELEMENTS } from "./dataset-libraries.js";
 export type { DatasetLibraryUiElement, DatasetLibraryKey } from "./dataset-libraries.js";
+
+export * from "./inference-providers.js";
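
Since packages/tasks/src/index.ts now re-exports everything from inference-providers.js, the provider list and types become importable from the package root. A minimal consumption sketch (the type-guard helper is illustrative, not part of this diff):

import { INFERENCE_PROVIDERS, type InferenceProvider } from "@huggingface/tasks";

// Narrow an arbitrary string to the InferenceProvider union.
const isProvider = (x: string): x is InferenceProvider =>
	(INFERENCE_PROVIDERS as readonly string[]).includes(x);

isProvider("replicate"); // true
isProvider("openai"); // false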
14 changes: 14 additions & 0 deletions packages/tasks/src/inference-providers.ts
@@ -0,0 +1,14 @@
+export const INFERENCE_PROVIDERS = ["hf-inference", "fal-ai", "replicate", "sambanova", "together"] as const;
+
+export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
+
+export const HF_HUB_INFERENCE_PROXY_TEMPLATE = `https://huggingface.co/api/inference-proxy/{{PROVIDER}}`;
+
+/**
+ * URL to set as baseUrl in the OpenAI SDK.
+ */
+export function openAIbaseUrl(provider: InferenceProvider): string {
+	return provider === "hf-inference"
+		? "https://api-inference.huggingface.co/v1/"
+		: HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);
+}
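
As the doc comment notes, openAIbaseUrl produces the URL to hand to the OpenAI SDK. A sketch of that wiring; the provider choice, model id, and HF_TOKEN environment variable are placeholder assumptions, and the root import assumes the re-export added above:

import OpenAI from "openai";
import { openAIbaseUrl } from "@huggingface/tasks";

// Third-party providers are reached through the Hub proxy route;
// "hf-inference" keeps the classic https://api-inference.huggingface.co/v1/ base.
const client = new OpenAI({
	baseURL: openAIbaseUrl("together"),
	apiKey: process.env.HF_TOKEN, // a Hugging Face access token, not a provider key
});

const completion = await client.chat.completions.create({
	model: "meta-llama/Llama-3.1-8B-Instruct", // placeholder model id
	messages: [{ role: "user", content: "Hello!" }],
});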
94 changes: 71 additions & 23 deletions packages/tasks/src/snippets/curl.ts
@@ -1,29 +1,47 @@
+import { HF_HUB_INFERENCE_PROXY_TEMPLATE, type InferenceProvider } from "../inference-providers.js";
 import type { PipelineType } from "../pipelines.js";
 import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
 import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
 import { getModelInputSnippet } from "./inputs.js";
 import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
 
-export const snippetBasic = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
-	content: `curl https://api-inference.huggingface.co/models/${model.id} \\
-	-X POST \\
-	-d '{"inputs": ${getModelInputSnippet(model, true)}}' \\
-	-H 'Content-Type: application/json' \\
-	-H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
-});
+export const snippetBasic = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	provider: InferenceProvider
+): InferenceSnippet[] => {
+	if (provider !== "hf-inference") {
+		return [];
+	}
+	return [
+		{
+			content: `curl https://api-inference.huggingface.co/models/${model.id} \\
+	-X POST \\
+	-d '{"inputs": ${getModelInputSnippet(model, true)}}' \\
+	-H 'Content-Type: application/json' \\
+	-H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
+		},
+	];
+};
 
 export const snippetTextGeneration = (
 	model: ModelDataMinimal,
 	accessToken: string,
+	provider: InferenceProvider,
 	opts?: {
 		streaming?: boolean;
 		messages?: ChatCompletionInputMessage[];
 		temperature?: GenerationParameters["temperature"];
 		max_tokens?: GenerationParameters["max_tokens"];
 		top_p?: GenerationParameters["top_p"];
 	}
-): InferenceSnippet => {
+): InferenceSnippet[] => {
 	if (model.tags.includes("conversational")) {
+		const baseUrl =
+			provider === "hf-inference"
+				? `https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions`
+				: HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider) + "/v1/chat/completions";
+
 		// Conversational model detected, so we display a code snippet that features the Messages API
 		const streaming = opts?.streaming ?? true;
 		const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
@@ -34,8 +52,9 @@ export const snippetTextGeneration = (
 			max_tokens: opts?.max_tokens ?? 500,
 			...(opts?.top_p ? { top_p: opts.top_p } : undefined),
 		};
-		return {
-			content: `curl 'https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions' \\
+		return [
+			{
+				content: `curl '${baseUrl}' \\
 -H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}' \\
 -H 'Content-Type: application/json' \\
 --data '{
@@ -52,34 +71,62 @@ export const snippetTextGeneration = (
 	})},
     "stream": ${!!streaming}
 }'`,
-		};
+			},
+		];
 	} else {
-		return snippetBasic(model, accessToken);
+		return snippetBasic(model, accessToken, provider);
 	}
 };
 
-export const snippetZeroShotClassification = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
-	content: `curl https://api-inference.huggingface.co/models/${model.id} \\
+export const snippetZeroShotClassification = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	provider: InferenceProvider
+): InferenceSnippet[] => {
+	if (provider !== "hf-inference") {
+		return [];
+	}
+	return [
+		{
+			content: `curl https://api-inference.huggingface.co/models/${model.id} \\
 	-X POST \\
 	-d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\
 	-H 'Content-Type: application/json' \\
 	-H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
-});
+		},
+	];
+};
 
-export const snippetFile = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
-	content: `curl https://api-inference.huggingface.co/models/${model.id} \\
+export const snippetFile = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	provider: InferenceProvider
+): InferenceSnippet[] => {
+	if (provider !== "hf-inference") {
+		return [];
+	}
+	return [
+		{
+			content: `curl https://api-inference.huggingface.co/models/${model.id} \\
 	-X POST \\
 	--data-binary '@${getModelInputSnippet(model, true, true)}' \\
 	-H 'Authorization: Bearer ${accessToken || `{API_TOKEN}`}'`,
-});
+		},
+	];
+};
 
 export const curlSnippets: Partial<
 	Record<
 		PipelineType,
-		(model: ModelDataMinimal, accessToken: string, opts?: Record<string, unknown>) => InferenceSnippet
+		(
+			model: ModelDataMinimal,
+			accessToken: string,
+			provider: InferenceProvider,
+			opts?: Record<string, unknown>
+		) => InferenceSnippet[]
 	>
 > = {
-	// Same order as in js/src/lib/interfaces/Types.ts
+	// Same order as in tasks/src/pipelines.ts
 	"text-classification": snippetBasic,
 	"token-classification": snippetBasic,
 	"table-question-answering": snippetBasic,
@@ -108,11 +155,12 @@ export const curlSnippets: Partial<
 export function getCurlInferenceSnippet(
 	model: ModelDataMinimal,
 	accessToken: string,
+	provider: InferenceProvider,
 	opts?: Record<string, unknown>
-): InferenceSnippet {
+): InferenceSnippet[] {
 	return model.pipeline_tag && model.pipeline_tag in curlSnippets
-		? curlSnippets[model.pipeline_tag]?.(model, accessToken, opts) ?? { content: "" }
-		: { content: "" };
+		? curlSnippets[model.pipeline_tag]?.(model, accessToken, provider, opts) ?? [{ content: "" }]
+		: [{ content: "" }];
 }
 
 export function hasCurlInferenceSnippet(model: Pick<ModelDataMinimal, "pipeline_tag">): boolean {
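
At the entry point, callers now pass a provider and receive a list of snippets rather than a single one. A sketch under the same stub assumptions (conversational models are the one task routed to third-party providers):

import { getCurlInferenceSnippet } from "./curl.js";
import type { ModelDataMinimal } from "./types.js";

// Hypothetical conversational model stub.
const chatModel = {
	id: "meta-llama/Llama-3.1-8B-Instruct",
	pipeline_tag: "text-generation",
	tags: ["conversational"],
} as ModelDataMinimal;

// One snippet, targeting the proxy route:
// https://huggingface.co/api/inference-proxy/fal-ai/v1/chat/completions
const [snippet] = getCurlInferenceSnippet(chatModel, "{API_TOKEN}", "fal-ai");
console.log(snippet?.content);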