
Commit 9ef486b

Merge branch 'main' into kai/hyperbolic-integration
2 parents: cba9580 + 5a394d2

14 files changed: +470 / -160 lines

.github/workflows/test.yml

Lines changed: 8 additions & 0 deletions
@@ -41,13 +41,16 @@ jobs:
         run: VCR_MODE=playback pnpm --filter ...[${{ steps.since.outputs.SINCE }}] test
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
+          HF_BLACK_FOREST_LABS_KEY: dummy
           HF_FAL_KEY: dummy
           HF_FIREWORKS_KEY: dummy
           HF_HYPERBOLIC_KEY: dummy
           HF_NEBIUS_KEY: dummy
+          HF_NOVITA_KEY: dummy
           HF_REPLICATE_KEY: dummy
           HF_SAMBANOVA_KEY: dummy
           HF_TOGETHER_KEY: dummy
+
   browser:
     runs-on: ubuntu-latest
     timeout-minutes: 10
@@ -83,13 +86,16 @@ jobs:
         run: VCR_MODE=playback pnpm --filter ...[${{ steps.since.outputs.SINCE }}] test:browser
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
+          HF_BLACK_FOREST_LABS_KEY: dummy
           HF_FAL_KEY: dummy
           HF_FIREWORKS_KEY: dummy
           HF_HYPERBOLIC_KEY: dummy
           HF_NEBIUS_KEY: dummy
+          HF_NOVITA_KEY: dummy
           HF_REPLICATE_KEY: dummy
           HF_SAMBANOVA_KEY: dummy
           HF_TOGETHER_KEY: dummy
+
   e2e:
     runs-on: ubuntu-latest
     timeout-minutes: 10
@@ -152,10 +158,12 @@ jobs:
         env:
           NPM_CONFIG_REGISTRY: http://localhost:4874/
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
+          HF_BLACK_FOREST_LABS_KEY: dummy
           HF_FAL_KEY: dummy
           HF_FIREWORKS_KEY: dummy
           HF_HYPERBOLIC_KEY: dummy
           HF_NEBIUS_KEY: dummy
+          HF_NOVITA_KEY: dummy
           HF_REPLICATE_KEY: dummy
           HF_SAMBANOVA_KEY: dummy
           HF_TOGETHER_KEY: dummy

packages/inference/README.md

Lines changed: 2 additions & 0 deletions
@@ -51,9 +51,11 @@ Currently, we support the following providers:
 - [Fireworks AI](https://fireworks.ai)
 - [Hyperbolic](https://hyperbolic.xyz)
 - [Nebius](https://studio.nebius.ai)
+- [Novita](https://novita.ai/?utm_source=github_huggingface&utm_medium=github_readme&utm_campaign=link)
 - [Replicate](https://replicate.com)
 - [Sambanova](https://sambanova.ai)
 - [Together](https://together.xyz)
+- [Blackforestlabs](https://blackforestlabs.ai)
 
 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
 ```ts
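The README code snippet itself is truncated in this diff. As an illustrative sketch only (placeholder token and model ID, assuming the `chatCompletion` method of the `HfInference` client), a request routed through one of the newly listed providers looks roughly like this:

```ts
import { HfInference } from "@huggingface/inference";

const client = new HfInference("hf_xxx"); // placeholder access token

// Pass `provider` to route the call through Novita instead of hf-inference
const out = await client.chatCompletion({
  provider: "novita",
  model: "deepseek-ai/DeepSeek-R1", // placeholder model ID
  messages: [{ role: "user", content: "Hello, nice to meet you!" }],
  max_tokens: 128,
});

console.log(out.choices[0].message.content);
```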

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 35 additions & 15 deletions
@@ -4,8 +4,10 @@ import { NEBIUS_API_BASE_URL } from "../providers/nebius";
 import { REPLICATE_API_BASE_URL } from "../providers/replicate";
 import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
 import { TOGETHER_API_BASE_URL } from "../providers/together";
+import { NOVITA_API_BASE_URL } from "../providers/novita";
 import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
 import { HYPERBOLIC_API_BASE_URL } from "../providers/hyperbolic";
+import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/black-forest-labs";
 import type { InferenceProvider } from "../types";
 import type { InferenceTask, Options, RequestArgs } from "../types";
 import { isUrl } from "./isUrl";
@@ -29,8 +31,6 @@ export async function makeRequestOptions(
     stream?: boolean;
   },
   options?: Options & {
-    /** When a model can be used for multiple tasks, and we want to run a non-default task */
-    forceTask?: string | InferenceTask;
     /** To load default model if needed */
     taskHint?: InferenceTask;
     chatCompletion?: boolean;
@@ -40,14 +40,11 @@ export async function makeRequestOptions(
   let otherArgs = remainingArgs;
   const provider = maybeProvider ?? "hf-inference";
 
-  const { forceTask, includeCredentials, taskHint, chatCompletion } = options ?? {};
+  const { includeCredentials, taskHint, chatCompletion } = options ?? {};
 
   if (endpointUrl && provider !== "hf-inference") {
     throw new Error(`Cannot use endpointUrl with a third-party provider.`);
   }
-  if (forceTask && provider !== "hf-inference") {
-    throw new Error(`Cannot use forceTask with a third-party provider.`);
-  }
   if (maybeModel && isUrl(maybeModel)) {
     throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
   }
@@ -78,16 +75,20 @@
     : makeUrl({
         authMethod,
         chatCompletion: chatCompletion ?? false,
-        forceTask,
         model,
         provider: provider ?? "hf-inference",
         taskHint,
       });
 
   const headers: Record<string, string> = {};
   if (accessToken) {
-    headers["Authorization"] =
-      provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
+    if (provider === "fal-ai" && authMethod === "provider-key") {
+      headers["Authorization"] = `Key ${accessToken}`;
+    } else if (provider === "black-forest-labs" && authMethod === "provider-key") {
+      headers["X-Key"] = accessToken;
+    } else {
+      headers["Authorization"] = `Bearer ${accessToken}`;
+    }
   }
 
   // e.g. @huggingface/inference/3.1.3
@@ -149,14 +150,19 @@ function makeUrl(params: {
   model: string;
   provider: InferenceProvider;
   taskHint: InferenceTask | undefined;
-  forceTask?: string | InferenceTask;
 }): string {
   if (params.authMethod === "none" && params.provider !== "hf-inference") {
     throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
   }
 
   const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
   switch (params.provider) {
+    case "black-forest-labs": {
+      const baseUrl = shouldProxy
+        ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
+        : BLACKFORESTLABS_AI_API_BASE_URL;
+      return `${baseUrl}/${params.model}`;
+    }
     case "fal-ai": {
       const baseUrl = shouldProxy
         ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
@@ -216,6 +222,7 @@ function makeUrl(params: {
       }
       return baseUrl;
     }
+
     case "fireworks-ai": {
       const baseUrl = shouldProxy
        ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
@@ -235,15 +242,28 @@ function makeUrl(params: {
       }
       return `${baseUrl}/v1/chat/completions`;
     }
+    case "novita": {
+      const baseUrl = shouldProxy
+        ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
+        : NOVITA_API_BASE_URL;
+      if (params.taskHint === "text-generation") {
+        if (params.chatCompletion) {
+          return `${baseUrl}/chat/completions`;
+        }
+        return `${baseUrl}/completions`;
+      }
+      return baseUrl;
+    }
     default: {
       const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
-      const url = params.forceTask
-        ? `${baseUrl}/pipeline/${params.forceTask}/${params.model}`
-        : `${baseUrl}/models/${params.model}`;
+      if (params.taskHint && ["feature-extraction", "sentence-similarity"].includes(params.taskHint)) {
+        /// when deployed on hf-inference, those two tasks are automatically compatible with one another.
+        return `${baseUrl}/pipeline/${params.taskHint}/${params.model}`;
+      }
       if (params.taskHint === "text-generation" && params.chatCompletion) {
-        return url + `/v1/chat/completions`;
+        return `${baseUrl}/models/${params.model}/v1/chat/completions`;
      }
-      return url;
+      return `${baseUrl}/models/${params.model}`;
     }
   }
 }
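For readers skimming the diff, here is a minimal standalone restatement of the new header selection (not the library's exported API; the `AuthMethod` union below is assumed from the surrounding file):

```ts
type AuthMethod = "none" | "hf-token" | "credentials-include" | "provider-key"; // assumed union

// Mirrors the branch added above: fal-ai keeps `Key <token>`, direct Black Forest Labs
// calls send the key in an `X-Key` header, and every other case stays on `Bearer <token>`.
function pickAuthHeader(provider: string, authMethod: AuthMethod, accessToken: string): Record<string, string> {
  if (provider === "fal-ai" && authMethod === "provider-key") {
    return { Authorization: `Key ${accessToken}` };
  }
  if (provider === "black-forest-labs" && authMethod === "provider-key") {
    return { "X-Key": accessToken };
  }
  return { Authorization: `Bearer ${accessToken}` };
}
```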
packages/inference/src/providers/black-forest-labs.ts

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
+export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
+
+/**
+ * See the registered mapping of HF model ID => Black Forest Labs model ID here:
+ *
+ * https://huggingface.co/api/partners/blackforestlabs/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+ *
+ * - If you work at Black Forest Labs and want to update this mapping, please use the model mapping API we provide on huggingface.co
+ * - If you're a community member and want to add a new supported HF model to Black Forest Labs, please open an issue on the present repo
+ *   and we will tag Black Forest Labs team members.
+ *
+ * Thanks!
+ */

packages/inference/src/providers/consts.ts

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
   * Example:
   * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
   */
+  "black-forest-labs": {},
   "fal-ai": {},
   "fireworks-ai": {},
   "hf-inference": {},
@@ -24,4 +25,5 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
   replicate: {},
   sambanova: {},
   together: {},
+  novita: {},
 };
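As a purely hypothetical illustration of the dev-time escape hatch described in the provider files (both IDs below are placeholders, not a registered mapping), an entry would follow the same "HF model ID => provider model ID" shape as the Qwen example in the comment above:

```ts
// Hypothetical local-dev override, e.g. while a model is not yet registered on huggingface.co.
// In the real file this would live under the "black-forest-labs" key of HARDCODED_MODEL_ID_MAPPING.
const devOnlyBlackForestLabsMapping: Record<string, string> = {
  // HF model ID => Black Forest Labs model ID (placeholder values)
  "black-forest-labs/FLUX.1-dev": "flux-dev",
};
```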
packages/inference/src/providers/novita.ts

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
+export const NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
+
+/**
+ * See the registered mapping of HF model ID => Novita model ID here:
+ *
+ * https://huggingface.co/api/partners/novita/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+ *
+ * - If you work at Novita and want to update this mapping, please use the model mapping API we provide on huggingface.co
+ * - If you're a community member and want to add a new supported HF model to Novita, please open an issue on the present repo
+ *   and we will tag Novita team members.
+ *
+ * Thanks!
+ */

packages/inference/src/tasks/cv/textToImage.ts

Lines changed: 44 additions & 1 deletion
@@ -3,6 +3,7 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import type { BaseArgs, InferenceProvider, Options } from "../../types";
 import { omit } from "../../utils/omit";
 import { request } from "../custom/request";
+import { delay } from "../../utils/delay";
 
 export type TextToImageArgs = BaseArgs & TextToImageInput;
 
@@ -18,6 +19,11 @@ interface HyperbolicTextToImageOutput {
   images: Array<{ image: string }>;
 }
 
+interface BlackForestLabsResponse {
+  id: string;
+  polling_url: string;
+}
+
 function getResponseFormatArg(provider: InferenceProvider) {
   switch (provider) {
     case "fal-ai":
@@ -48,13 +54,20 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
     prompt: args.inputs,
   };
   const res = await request<
-    TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration | HyperbolicTextToImageOutput
+    | TextToImageOutput
+    | Base64ImageGeneration
+    | OutputUrlImageGeneration
+    | BlackForestLabsResponse
+    | HyperbolicTextToImageOutput
   >(payload, {
     ...options,
     taskHint: "text-to-image",
   });
 
   if (res && typeof res === "object") {
+    if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
+      return await pollBflResponse(res.polling_url);
+    }
     if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
       const image = await fetch(res.images[0].url);
       return await image.blob();
@@ -88,3 +101,33 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
   }
   return res;
 }
+
+async function pollBflResponse(url: string): Promise<Blob> {
+  const urlObj = new URL(url);
+  for (let step = 0; step < 5; step++) {
+    await delay(1000);
+    console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
+    urlObj.searchParams.set("attempt", step.toString(10));
+    const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
+    if (!resp.ok) {
+      throw new InferenceOutputError("Failed to fetch result from black forest labs API");
+    }
+    const payload = await resp.json();
+    if (
+      typeof payload === "object" &&
+      payload &&
+      "status" in payload &&
+      typeof payload.status === "string" &&
+      payload.status === "Ready" &&
+      "result" in payload &&
+      typeof payload.result === "object" &&
+      payload.result &&
+      "sample" in payload.result &&
+      typeof payload.result.sample === "string"
+    ) {
+      const image = await fetch(payload.result.sample);
+      return await image.blob();
+    }
+  }
+  throw new InferenceOutputError("Failed to fetch result from black forest labs API");
+}
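A minimal caller-side sketch (not part of the diff; the token and model ID are placeholders) showing that the polling loop above stays internal, so the caller still just awaits a `Blob`:

```ts
import { textToImage } from "@huggingface/inference";

// The library resolves the BFL `polling_url` internally (up to 5 attempts, 1 s apart)
// before resolving with the generated image.
const image: Blob = await textToImage({
  accessToken: "hf_xxx", // HF token (proxied) or a Black Forest Labs key (direct)
  provider: "black-forest-labs",
  model: "black-forest-labs/FLUX.1-dev", // placeholder model ID
  inputs: "An astronaut riding a horse on the moon, photorealistic",
});
```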

packages/inference/src/tasks/nlp/featureExtraction.ts

Lines changed: 0 additions & 4 deletions
@@ -1,5 +1,4 @@
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
-import { getDefaultTask } from "../../lib/getDefaultTask";
 import type { BaseArgs, Options } from "../../types";
 import { request } from "../custom/request";
 
@@ -25,12 +24,9 @@
   args: FeatureExtractionArgs,
   options?: Options
 ): Promise<FeatureExtractionOutput> {
-  const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : undefined;
-
   const res = await request<FeatureExtractionOutput>(args, {
     ...options,
     taskHint: "feature-extraction",
-    ...(defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" }),
   });
   let isValidOutput = true;
 
packages/inference/src/tasks/nlp/sentenceSimilarity.ts

Lines changed: 0 additions & 3 deletions
@@ -1,6 +1,5 @@
 import type { SentenceSimilarityInput, SentenceSimilarityOutput } from "@huggingface/tasks";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
-import { getDefaultTask } from "../../lib/getDefaultTask";
 import type { BaseArgs, Options } from "../../types";
 import { request } from "../custom/request";
 import { omit } from "../../utils/omit";
@@ -14,11 +13,9 @@ export async function sentenceSimilarity(
   args: SentenceSimilarityArgs,
   options?: Options
 ): Promise<SentenceSimilarityOutput> {
-  const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : undefined;
   const res = await request<SentenceSimilarityOutput>(prepareInput(args), {
     ...options,
     taskHint: "sentence-similarity",
-    ...(defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }),
   });
 
   const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
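For context, a minimal sketch (not from the diff; the model ID is a placeholder) of the two affected tasks after the `getDefaultTask`/`forceTask` removal; both now resolve to the hf-inference `/pipeline/<task>/<model>` route directly from their `taskHint`:

```ts
import { featureExtraction, sentenceSimilarity } from "@huggingface/inference";

// Both calls go to .../pipeline/<taskHint>/<model> on hf-inference, so a
// sentence-similarity model can also serve feature-extraction requests.
const embedding = await featureExtraction({
  accessToken: "hf_xxx",
  model: "sentence-transformers/all-MiniLM-L6-v2", // placeholder model ID
  inputs: "That is a happy person",
});

const scores = await sentenceSimilarity({
  accessToken: "hf_xxx",
  model: "sentence-transformers/all-MiniLM-L6-v2", // placeholder model ID
  inputs: {
    source_sentence: "That is a happy person",
    sentences: ["That is a happy dog", "Today is a sunny day"],
  },
});
```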

packages/inference/src/types.ts

Lines changed: 4 additions & 1 deletion
@@ -29,15 +29,18 @@ export interface Options {
 export type InferenceTask = Exclude<PipelineType, "other">;
 
 export const INFERENCE_PROVIDERS = [
+  "black-forest-labs",
   "fal-ai",
   "fireworks-ai",
+  "hf-inference",
   "hyperbolic",
   "nebius",
-  "hf-inference",
+  "novita",
   "replicate",
   "sambanova",
   "together",
 ] as const;
+
 export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
 
 export interface BaseArgs {
