diff --git a/README.md b/README.md
index 1774b8de57..845dde3b2f 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@ await uploadFile({
   }
 });
 
-// Use Inference API
+// Use HF Inference API
 await inference.chatCompletion({
   model: "meta-llama/Llama-3.1-8B-Instruct",
   messages: [{ role: "user", content: "Hello, nice to meet you!" }],
@@ -53,7 +53,7 @@ await inference.textToImage({
 
 This is a collection of JS libraries to interact with the Hugging Face API, with TS types included.
 
-- [@huggingface/inference](packages/inference/README.md): Use Inference API (serverless) and Inference Endpoints (dedicated) to make calls to 100,000+ Machine Learning models
+- [@huggingface/inference](packages/inference/README.md): Use Inference API (serverless), Inference Endpoints (dedicated) and third-party Inference providers to make calls to 100,000+ Machine Learning models
 - [@huggingface/hub](packages/hub/README.md): Interact with huggingface.co to create or delete repos and commit / download files
 - [@huggingface/agents](packages/agents/README.md): Interact with HF models through a natural language interface
 - [@huggingface/gguf](packages/gguf/README.md): A GGUF parser that works on remotely hosted files.
@@ -144,6 +144,20 @@ for await (const chunk of inference.chatCompletionStream({
   console.log(chunk.choices[0].delta.content);
 }
 
+// Using a third-party provider:
+await inference.chatCompletion({
+  model: "meta-llama/Llama-3.1-8B-Instruct",
+  messages: [{ role: "user", content: "Hello, nice to meet you!" }],
+  max_tokens: 512,
+  provider: "sambanova"
+});
+
+await inference.textToImage({
+  model: "black-forest-labs/FLUX.1-dev",
+  inputs: "a picture of a green bird",
+  provider: "together"
+});
+
 // You can also omit "model" to use the recommended model for the task
 await inference.translation({
   inputs: "My name is Wolfgang and I live in Amsterdam",
diff --git a/packages/inference/README.md b/packages/inference/README.md
index 9c20c418b3..90ff498933 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -42,6 +42,34 @@ const hf = new HfInference('your access token')
 Your access token should be kept private. If you need to protect it in front-end applications, we suggest setting up a proxy server that stores the access token.
 
+### Requesting third-party inference providers
+
+You can request inference from third-party providers with the inference client.
+
+Currently, we support the following providers: [Fal.ai](https://fal.ai), [Replicate](https://replicate.com), [Together](https://together.xyz) and [Sambanova](https://sambanova.ai).
+
+To make a request to a third-party provider, pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
+```ts
+const accessToken = "hf_..."; // Either an HF access token, or an API key from the third-party provider (Replicate in this example)
+
+const client = new HfInference(accessToken);
+await client.textToImage({
+  provider: "replicate",
+  model: "black-forest-labs/FLUX.1-dev",
+  inputs: "A black forest cake"
+});
+```
+
+When authenticated with a Hugging Face access token, the request is routed through https://huggingface.co.
+When authenticated with a third-party provider key, the request is made directly against that provider's inference API.
+
+Only a subset of models is supported when requesting third-party providers.
+You can check the list of supported models per pipeline task here:
+- [Fal.ai supported models](./src/providers/fal-ai.ts)
+- [Replicate supported models](./src/providers/replicate.ts)
+- [Sambanova supported models](./src/providers/sambanova.ts)
+- [Together supported models](./src/providers/together.ts)
+- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
 
 #### Tree-shaking
 
 You can import the functions you need directly from the module instead of using the `HfInference` class.
diff --git a/packages/inference/src/index.ts b/packages/inference/src/index.ts
index 3934a0493f..84255e8fae 100644
--- a/packages/inference/src/index.ts
+++ b/packages/inference/src/index.ts
@@ -1,3 +1,4 @@
+export type { ProviderMapping } from "./providers/types";
 export { HfInference, HfInferenceEndpoint } from "./HfInference";
 export { InferenceOutputError } from "./lib/InferenceOutputError";
 export { FAL_AI_SUPPORTED_MODEL_IDS } from "./providers/fal-ai";
diff --git a/packages/inference/src/lib/makeRequestOptions.ts b/packages/inference/src/lib/makeRequestOptions.ts
index 233297e1e2..f3798e2758 100644
--- a/packages/inference/src/lib/makeRequestOptions.ts
+++ b/packages/inference/src/lib/makeRequestOptions.ts
@@ -1,3 +1,4 @@
+import type { WidgetType } from "@huggingface/tasks";
 import { HF_HUB_URL, HF_INFERENCE_API_URL } from "../config";
 import { FAL_AI_API_BASE_URL, FAL_AI_SUPPORTED_MODEL_IDS } from "../providers/fal-ai";
 import { REPLICATE_API_BASE_URL, REPLICATE_SUPPORTED_MODEL_IDS } from "../providers/replicate";
@@ -65,21 +66,21 @@ export async function makeRequestOptions(
 			? "hf-token"
 			: "provider-key"
 		: includeCredentials === "include"
-		? "credentials-include"
-		: "none";
+		  ? "credentials-include"
+		  : "none";
 
 	const url = endpointUrl
 		? chatCompletion
 			? endpointUrl + `/v1/chat/completions`
 			: endpointUrl
 		: makeUrl({
-			authMethod,
-			chatCompletion: chatCompletion ?? false,
-			forceTask,
-			model,
-			provider: provider ?? "hf-inference",
-			taskHint,
-		});
+				authMethod,
+				chatCompletion: chatCompletion ?? false,
+				forceTask,
+				model,
+				provider: provider ?? "hf-inference",
+				taskHint,
+		  });
 
 	const headers: Record<string, string> = {};
 	if (accessToken) {
@@ -133,9 +134,9 @@ export async function makeRequestOptions(
 		body: binary
 			? args.data
 			: JSON.stringify({
-				...otherArgs,
-				...(chatCompletion || provider === "together" ? { model } : undefined),
-			}),
+					...otherArgs,
+					...(chatCompletion || provider === "together" ? { model } : undefined),
+			  }),
 		...(credentials ? { credentials } : undefined),
 		signal: options?.signal,
 	};
@@ -155,7 +156,7 @@ function mapModel(params: {
 	if (!params.taskHint) {
 		throw new Error("taskHint must be specified when using a third-party provider");
 	}
-	const task = params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
+	const task: WidgetType = params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
 	const model = (() => {
 		switch (params.provider) {
 			case "fal-ai":
diff --git a/packages/inference/src/providers/types.ts b/packages/inference/src/providers/types.ts
index c037b46ccf..a0335d9f83 100644
--- a/packages/inference/src/providers/types.ts
+++ b/packages/inference/src/providers/types.ts
@@ -1,5 +1,6 @@
-import type { InferenceTask, ModelId } from "../types";
+import type { WidgetType } from "@huggingface/tasks";
+import type { ModelId } from "../types";
 
 export type ProviderMapping = Partial<
-	Record<InferenceTask, Partial<Record<ModelId, string>>>
+	Record<WidgetType, Partial<Record<ModelId, string>>>
 >;
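The README additions above describe two authentication modes for the new `provider` parameter. Here is a minimal sketch exercising both; the token values (`hf_...`, `r8_...`) are placeholders, and the model IDs are the ones used in the diff:

```ts
import { HfInference } from "@huggingface/inference";

// Routed mode: an HF access token ("hf_" prefix) sends the request through
// huggingface.co, which forwards it to the selected provider.
const routed = new HfInference("hf_...");
await routed.chatCompletion({
  model: "meta-llama/Llama-3.1-8B-Instruct",
  messages: [{ role: "user", content: "Hello, nice to meet you!" }],
  max_tokens: 512,
  provider: "sambanova",
});

// Direct mode: a provider's own API key (placeholder value here) sends the
// request straight to that provider's inference API, bypassing huggingface.co.
const direct = new HfInference("r8_...");
await direct.textToImage({
  model: "black-forest-labs/FLUX.1-dev",
  inputs: "A black forest cake",
  provider: "replicate",
});
```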
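The newly exported `ProviderMapping` type keys the per-provider model tables by `WidgetType` rather than `InferenceTask`, so `"conversational"` becomes a valid key once `mapModel` folds chat-completion requests into that task. A minimal sketch of an instance of that shape, with illustrative placeholder entries rather than the real mappings kept in `src/providers/*.ts`:

```ts
import type { ProviderMapping } from "@huggingface/inference";

// Task (WidgetType) -> { HF model ID -> provider-specific model ID }.
// The entries below are hypothetical examples, not authoritative mappings.
const exampleMapping: ProviderMapping = {
  "text-to-image": {
    "black-forest-labs/FLUX.1-dev": "black-forest-labs/FLUX.1-dev",
  },
  conversational: {
    "meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
  },
};
```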