20 changes: 18 additions & 2 deletions README.md
@@ -27,7 +27,7 @@ await uploadFile({
   }
 });
 
-// Use Inference API
+// Use HF Inference API
 
 await inference.chatCompletion({
   model: "meta-llama/Llama-3.1-8B-Instruct",
@@ -53,7 +53,7 @@ await inference.textToImage({
 
 This is a collection of JS libraries to interact with the Hugging Face API, with TS types included.
 
-- [@huggingface/inference](packages/inference/README.md): Use Inference API (serverless) and Inference Endpoints (dedicated) to make calls to 100,000+ Machine Learning models
+- [@huggingface/inference](packages/inference/README.md): Use Inference API (serverless), Inference Endpoints (dedicated) and third-party Inference providers to make calls to 100,000+ Machine Learning models
 - [@huggingface/hub](packages/hub/README.md): Interact with huggingface.co to create or delete repos and commit / download files
 - [@huggingface/agents](packages/agents/README.md): Interact with HF models through a natural language interface
 - [@huggingface/gguf](packages/gguf/README.md): A GGUF parser that works on remotely hosted files.
@@ -144,6 +144,22 @@ for await (const chunk of inference.chatCompletionStream({
   console.log(chunk.choices[0].delta.content);
 }
 
+// Using a third-party provider:
+await inference.chatCompletion({
+  model: "meta-llama/Llama-3.1-8B-Instruct",
+  messages: [{ role: "user", content: "Hello, nice to meet you!" }],
+  max_tokens: 512,
+  provider: "sambanova",
+});
+
+await inference.textToImage({
+  model: "black-forest-labs/FLUX.1-dev",
+  inputs: "a picture of a green bird",
+  provider: "together",
+});
+
+
+
 // You can also omit "model" to use the recommended model for the task
 await inference.translation({
   inputs: "My name is Wolfgang and I live in Amsterdam",
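The `provider` option added above should compose with streaming as well. A minimal sketch (not part of the diff), assuming `chatCompletionStream` honors `provider` the same way `chatCompletion` does:

```ts
import { HfInference } from "@huggingface/inference";

const inference = new HfInference("hf_..."); // HF access token, so the call is routed through huggingface.co

// stream tokens from a chat model served by a third-party provider
for await (const chunk of inference.chatCompletionStream({
  model: "meta-llama/Llama-3.1-8B-Instruct",
  messages: [{ role: "user", content: "Hello, nice to meet you!" }],
  max_tokens: 512,
  provider: "sambanova",
})) {
  console.log(chunk.choices[0].delta.content);
}
```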
28 changes: 28 additions & 0 deletions packages/inference/README.md
@@ -42,6 +42,34 @@ const hf = new HfInference('your access token')
 
 Your access token should be kept private. If you need to protect it in front-end applications, we suggest setting up a proxy server that stores the access token.
 
+### Requesting third-party inference providers
+
+You can request inference from third-party providers with the inference client.
+
+Currently, we support the following providers: [Fal.ai](https://fal.ai), [Replicate](https://replicate.com), [Together](https://together.xyz) and [Sambanova](https://sambanova.ai).
+
+To make a request to a third-party provider, pass the `provider` parameter to the inference function and make sure the request is authenticated with an access token.
+```ts
+const accessToken = "hf_..."; // Either a HF access token, or an API key from the 3rd-party provider (Replicate in this example)
+
+const client = new HfInference(accessToken);
+await client.textToImage({
+  provider: "replicate",
+  model: "black-forest-labs/FLUX.1-dev",
+  inputs: "A black forest cake",
+});
+```
+
+When authenticated with a Hugging Face access token, the request is routed through https://huggingface.co.
+When authenticated with a third-party provider key, the request is made directly against that provider's inference API.
+
+Only a subset of models is supported when requesting third-party providers. You can check the list of supported models per pipeline task here:
+- [Fal.ai supported models](./src/providers/fal-ai.ts)
+- [Replicate supported models](./src/providers/replicate.ts)
+- [Sambanova supported models](./src/providers/sambanova.ts)
+- [Together supported models](./src/providers/together.ts)
+- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
+
 #### Tree-shaking
 
 You can import the functions you need directly from the module instead of using the `HfInference` class.
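The Tree-shaking section above is truncated by the collapsed diff; as a sketch of the pattern it describes, each task is also exported as a standalone function that takes the access token alongside the usual arguments, so bundlers can drop the tasks you don't use:

```ts
import { textGeneration } from "@huggingface/inference";

// standalone task function: pass accessToken directly instead of constructing HfInference
const { generated_text } = await textGeneration({
  accessToken: "hf_...",
  model: "meta-llama/Llama-3.1-8B-Instruct",
  inputs: "Once upon a time,",
});
console.log(generated_text);
```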
1 change: 1 addition & 0 deletions packages/inference/src/index.ts
@@ -1,3 +1,4 @@
+export type { ProviderMapping } from "./providers/types";
 export { HfInference, HfInferenceEndpoint } from "./HfInference";
 export { InferenceOutputError } from "./lib/InferenceOutputError";
 export { FAL_AI_SUPPORTED_MODEL_IDS } from "./providers/fal-ai";
27 changes: 14 additions & 13 deletions packages/inference/src/lib/makeRequestOptions.ts
@@ -1,3 +1,4 @@
+import type { WidgetType } from "@huggingface/tasks";
 import { HF_HUB_URL, HF_INFERENCE_API_URL } from "../config";
 import { FAL_AI_API_BASE_URL, FAL_AI_SUPPORTED_MODEL_IDS } from "../providers/fal-ai";
 import { REPLICATE_API_BASE_URL, REPLICATE_SUPPORTED_MODEL_IDS } from "../providers/replicate";
@@ -65,21 +66,21 @@ export async function makeRequestOptions(
       ? "hf-token"
       : "provider-key"
     : includeCredentials === "include"
-    ? "credentials-include"
-    : "none";
+      ? "credentials-include"
+      : "none";
 
   const url = endpointUrl
     ? chatCompletion
       ? endpointUrl + `/v1/chat/completions`
      : endpointUrl
     : makeUrl({
-        authMethod,
-        chatCompletion: chatCompletion ?? false,
-        forceTask,
-        model,
-        provider: provider ?? "hf-inference",
-        taskHint,
-      });
+      authMethod,
+      chatCompletion: chatCompletion ?? false,
+      forceTask,
+      model,
+      provider: provider ?? "hf-inference",
+      taskHint,
+    });
 
   const headers: Record<string, string> = {};
   if (accessToken) {
@@ -133,9 +134,9 @@
     body: binary
       ? args.data
       : JSON.stringify({
-          ...otherArgs,
-          ...(chatCompletion || provider === "together" ? { model } : undefined),
-        }),
+        ...otherArgs,
+        ...(chatCompletion || provider === "together" ? { model } : undefined),
+      }),
     ...(credentials ? { credentials } : undefined),
     signal: options?.signal,
   };
@@ -155,7 +156,7 @@ function mapModel(params: {
   if (!params.taskHint) {
     throw new Error("taskHint must be specified when using a third-party provider");
   }
-  const task = params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
+  const task: WidgetType = params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
   const model = (() => {
     switch (params.provider) {
       case "fal-ai":
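For reference, the mapping that `mapModel` consults is keyed by task first (note the `"conversational"` normalization above), then by HF model id. A small illustrative sketch of the lookup, using the exported Fal.ai mapping; the specific entry is an assumption, not guaranteed to be present:

```ts
import { FAL_AI_SUPPORTED_MODEL_IDS } from "@huggingface/inference";

// index by task, then by HF model id; the result is the provider-specific model id
const providerModelId =
  FAL_AI_SUPPORTED_MODEL_IDS["text-to-image"]?.["black-forest-labs/FLUX.1-dev"];
console.log(providerModelId ?? "model not supported by this provider");
```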
5 changes: 3 additions & 2 deletions packages/inference/src/providers/types.ts
@@ -1,5 +1,6 @@
-import type { InferenceTask, ModelId } from "../types";
+import type { WidgetType } from "@huggingface/tasks";
+import type { ModelId } from "../types";
 
 export type ProviderMapping<ProviderId extends string> = Partial<
-  Record<InferenceTask | "conversational", Partial<Record<ModelId, ProviderId>>>
+  Record<WidgetType, Partial<Record<ModelId, ProviderId>>>
 >;
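With `ProviderMapping` now keyed by `WidgetType` (and exported from the package root per the index.ts change above), a provider module can declare its model map for any widget task, including `"conversational"`. A hypothetical sketch; the provider id and model ids below are made up:

```ts
import type { ProviderMapping } from "@huggingface/inference";

type ExampleProviderId = string;

// hypothetical provider mapping: HF model ids -> this provider's own ids, grouped by task
export const EXAMPLE_SUPPORTED_MODEL_IDS: ProviderMapping<ExampleProviderId> = {
  "text-to-image": {
    "black-forest-labs/FLUX.1-dev": "example/flux-dev",
  },
  conversational: {
    "meta-llama/Llama-3.1-8B-Instruct": "example/llama-3.1-8b",
  },
};
```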