Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ jobs:
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy
HF_FIREWORKS_KEY: dummy
HF_BLACK_FOREST_LABS_KEY: dummy

browser:
runs-on: ubuntu-latest
Expand Down Expand Up @@ -89,6 +90,7 @@ jobs:
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy
HF_FIREWORKS_KEY: dummy
HF_BLACK_FOREST_LABS_KEY: dummy

e2e:
runs-on: ubuntu-latest
Expand Down Expand Up @@ -158,3 +160,4 @@ jobs:
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy
HF_FIREWORKS_KEY: dummy
HF_BLACK_FOREST_LABS_KEY: dummy
1 change: 1 addition & 0 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Currently, we support the following providers:
- [Replicate](https://replicate.com)
- [Sambanova](https://sambanova.ai)
- [Together](https://together.xyz)
- [Black Forest Labs](https://blackforestlabs.ai)

To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
```ts
Expand Down
16 changes: 14 additions & 2 deletions packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { REPLICATE_API_BASE_URL } from "../providers/replicate";
import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
import { TOGETHER_API_BASE_URL } from "../providers/together";
import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/black-forest-labs";
import type { InferenceProvider } from "../types";
import type { InferenceTask, Options, RequestArgs } from "../types";
import { isUrl } from "./isUrl";
Expand Down Expand Up @@ -85,8 +86,13 @@ export async function makeRequestOptions(

const headers: Record<string, string> = {};
if (accessToken) {
headers["Authorization"] =
provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
if (provider === "fal-ai" && authMethod === "provider-key") {
headers["Authorization"] = `Key ${accessToken}`;
} else if (provider === "black-forest-labs" && authMethod === "provider-key") {
headers["X-Key"] = accessToken;
} else {
headers["Authorization"] = `Bearer ${accessToken}`;
}
Comment on lines +84 to +90
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

}

// e.g. @huggingface/inference/3.1.3
Expand Down Expand Up @@ -154,6 +160,12 @@ function makeUrl(params: {

const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
switch (params.provider) {
case "black-forest-labs": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: BLACKFORESTLABS_AI_API_BASE_URL;
return `${baseUrl}/${params.model}`;
}
case "fal-ai": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
Expand Down
18 changes: 18 additions & 0 deletions packages/inference/src/providers/black-forest-labs.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Base URL of the Black Forest Labs public API (US cluster), used when calling
// the provider directly with a provider key (i.e. not proxied through huggingface.co).
export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";

/**
 * See the registered mapping of HF model ID => Black Forest Labs model ID here:
 *
 * https://huggingface.co/api/partners/blackforestlabs/models
 *
 * This is a publicly available mapping.
 *
 * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
 * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
 *
 * - If you work at Black Forest Labs and want to update this mapping, please use the model mapping API we provide on huggingface.co
 * - If you're a community member and want to add a new supported HF model to Black Forest Labs, please open an issue on the present repo
 *   and we will tag Black Forest Labs team members.
 *
 * Thanks!
 */
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
* Example:
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
*/
"black-forest-labs": {},
"fal-ai": {},
"fireworks-ai": {},
"hf-inference": {},
Expand Down
53 changes: 47 additions & 6 deletions packages/inference/src/tasks/cv/textToImage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, InferenceProvider, Options } from "../../types";
import { omit } from "../../utils/omit";
import { request } from "../custom/request";
import { delay } from "../../utils/delay";
import { randomUUID } from "crypto";

export type TextToImageArgs = BaseArgs & TextToImageInput;

Expand All @@ -14,6 +16,10 @@ interface Base64ImageGeneration {
interface OutputUrlImageGeneration {
output: string[];
}
// Shape of the initial (asynchronous) response from the Black Forest Labs API:
// the image is not returned inline; the caller must poll `polling_url` for the result.
interface BlackForestLabsResponse {
	// Generation request id assigned by the BFL API.
	id: string;
	// URL to poll until the generation result is ready.
	polling_url: string;
}

function getResponseFormatArg(provider: InferenceProvider) {
switch (provider) {
Expand All @@ -39,17 +45,22 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
!args.provider || args.provider === "hf-inference" || args.provider === "sambanova"
? args
: {
...omit(args, ["inputs", "parameters"]),
...args.parameters,
...getResponseFormatArg(args.provider),
prompt: args.inputs,
};
const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(payload, {
...omit(args, ["inputs", "parameters"]),
...args.parameters,
...getResponseFormatArg(args.provider),
prompt: args.inputs,
};
const res = await request<
TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration | BlackForestLabsResponse
>(payload, {
...options,
taskHint: "text-to-image",
});

if (res && typeof res === "object") {
if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
return await pollBflResponse(res.polling_url);
}
if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
const image = await fetch(res.images[0].url);
return await image.blob();
Expand All @@ -72,3 +83,33 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
}
return res;
}

/**
 * Polls the Black Forest Labs `polling_url` until the generated image is ready,
 * then downloads the sample and returns it as a Blob.
 *
 * Retries up to 5 times with a fixed 1-second delay between attempts
 * (the BFL API advertises `retry-after: 1` on pending results).
 *
 * NOTE(review): the polling request sends no auth header — presumably the signed
 * `polling_url` is self-authorizing; confirm this holds for the provider-key path.
 *
 * @param url the `polling_url` returned by the initial generation request
 * @returns the generated image as a Blob
 * @throws InferenceOutputError if polling or the image download fails, or the result
 *         is still not ready after all attempts
 */
async function pollBflResponse(url: string): Promise<Blob> {
	const urlObj = new URL(url);
	for (let step = 0; step < 5; step++) {
		await delay(1000);
		console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
		// Distinct query param per attempt so each poll is a distinct URL (avoids caching).
		urlObj.searchParams.set("attempt", step.toString(10));
		const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
		if (!resp.ok) {
			throw new InferenceOutputError("Failed to fetch result from black forest labs API");
		}
		const payload = await resp.json();
		// Narrow the untyped JSON payload: only accept { status: "Ready", result: { sample: string } }.
		if (
			typeof payload === "object" &&
			payload &&
			"status" in payload &&
			typeof payload.status === "string" &&
			payload.status === "Ready" &&
			"result" in payload &&
			typeof payload.result === "object" &&
			payload.result &&
			"sample" in payload.result &&
			typeof payload.result.sample === "string"
		) {
			const image = await fetch(payload.result.sample);
			// Fix: verify the sample download succeeded before returning its body;
			// otherwise an HTTP error page would be returned as the "image" blob.
			if (!image.ok) {
				throw new InferenceOutputError("Failed to fetch generated image from black forest labs API");
			}
			return await image.blob();
		}
	}
	throw new InferenceOutputError("Failed to fetch result from black forest labs API");
}
1 change: 1 addition & 0 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export const INFERENCE_PROVIDERS = [
"replicate",
"sambanova",
"together",
"black-forest-labs",
] as const;
export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];

Expand Down
5 changes: 5 additions & 0 deletions packages/inference/src/utils/delay.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
/**
 * Returns a promise that resolves after `ms` milliseconds.
 *
 * @param ms duration to wait, in milliseconds
 */
export function delay(ms: number): Promise<void> {
	// setTimeout can take the resolver directly; no extra value is passed on resolve.
	return new Promise<void>((resolve) => setTimeout(resolve, ms));
}
29 changes: 28 additions & 1 deletion packages/inference/test/HfInference.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { assert, describe, expect, it } from "vitest";

import type { ChatCompletionStreamOutput } from "@huggingface/tasks";

import { chatCompletion, HfInference } from "../src";
import { chatCompletion, HfInference, textToImage } from "../src";
import { textToVideo } from "../src/tasks/cv/textToVideo";
import { readTestFile } from "./test-files";
import "./vcr";
Expand Down Expand Up @@ -1175,4 +1175,31 @@ describe.concurrent("HfInference", () => {
},
TIMEOUT
);

describe.concurrent(
"Black Forest Labs",
() => {
HARDCODED_MODEL_ID_MAPPING["black-forest-labs"] = {
"black-forest-labs/FLUX.1-dev": "flux-dev",
// "black-forest-labs/FLUX.1-schnell": "flux-pro",
};

it("textToImage", async () => {
const res = await textToImage({
model: "black-forest-labs/FLUX.1-dev",
provider: "black-forest-labs",
accessToken: env.HF_BLACK_FOREST_LABS_KEY,
inputs: "A raccoon driving a truck",
parameters: {
height: 256,
width: 256,
num_inference_steps: 4,
seed: 8817,
},
});
expect(res).toBeInstanceOf(Blob);
});
},
TIMEOUT
);
});
75 changes: 75 additions & 0 deletions packages/inference/test/tapes.json
Original file line number Diff line number Diff line change
Expand Up @@ -7025,5 +7025,80 @@
"content-type": "image/jpeg"
}
}
},
"b320223c78e20541a47c961d89d24f507b0b0257224d91cd05744c93f2d67d2c": {
"url": "https://api.us1.bfl.ai/v1/flux-dev",
"init": {
"headers": {
"Content-Type": "application/json"
},
"method": "POST",
"body": "{\"height\":256,\"width\":256,\"num_inference_steps\":4,\"seed\":8817,\"prompt\":\"A raccoon driving a truck\"}"
},
"response": {
"body": "{\"id\":\"9cd5d992-3184-4e16-bd72-e000f7b1182a\",\"polling_url\":\"https://api.us1.bfl.ai/v1/get_result?id=9cd5d992-3184-4e16-bd72-e000f7b1182a\"}",
"status": 200,
"statusText": "OK",
"headers": {
"connection": "keep-alive",
"content-type": "application/json",
"strict-transport-security": "max-age=31536000; includeSubDomains"
}
}
},
"9b83ba327d80e12141b3d61f08ae13af8582aa52cc5e8e9f5f860749bea3b4c0": {
"url": "https://api.us1.bfl.ai/v1/get_result?id=9cd5d992-3184-4e16-bd72-e000f7b1182a&attempt=0",
"init": {
"headers": {
"Content-Type": "application/json"
}
},
"response": {
"body": "{\"id\":\"9cd5d992-3184-4e16-bd72-e000f7b1182a\",\"status\":\"Pending\",\"result\":null,\"progress\":0.7}",
"status": 200,
"statusText": "OK",
"headers": {
"connection": "keep-alive",
"content-type": "application/json",
"retry-after": "1",
"strict-transport-security": "max-age=31536000; includeSubDomains"
}
}
},
"0c3cc43124071a2bf53a619a9bce037c14e1ed790e8cce40db45389dbaee6b4e": {
"url": "https://api.us1.bfl.ai/v1/get_result?id=9cd5d992-3184-4e16-bd72-e000f7b1182a&attempt=1",
"init": {
"headers": {
"Content-Type": "application/json"
}
},
"response": {
"body": "{\"id\":\"9cd5d992-3184-4e16-bd72-e000f7b1182a\",\"status\":\"Ready\",\"result\":{\"sample\":\"https://delivery-us1.bfl.ai/results/5965f3f68d50412c9d62c08480d7cd75/sample.jpeg?se=2025-02-12T15%3A25%3A59Z&sp=r&sv=2024-11-04&sr=b&rsct=image/jpeg&sig=ZQfYFNXVQLTxwXJb/2HPPb6fl5ITyJukYiYsBY15iBs%3D\",\"prompt\":\"A raccoon driving a truck\",\"seed\":8817,\"start_time\":1739373357.096922,\"end_time\":1739373359.4733694,\"duration\":2.3764474391937256},\"progress\":null}",
"status": 200,
"statusText": "OK",
"headers": {
"connection": "keep-alive",
"content-type": "application/json",
"retry-after": "1",
"strict-transport-security": "max-age=31536000; includeSubDomains"
}
}
},
"999dc957255935037c9b76833430b679add51c628940a97289dd3e6f0944e60f": {
"url": "https://delivery-us1.bfl.ai/results/5965f3f68d50412c9d62c08480d7cd75/sample.jpeg?se=2025-02-12T15%3A25%3A59Z&sp=r&sv=2024-11-04&sr=b&rsct=image/jpeg&sig=ZQfYFNXVQLTxwXJb/2HPPb6fl5ITyJukYiYsBY15iBs%3D",
"init": {},
"response": {
"body": "",
"status": 200,
"statusText": "OK",
"headers": {
"accept-ranges": "bytes",
"connection": "keep-alive",
"content-md5": "ZzxARdjL4KeKFdZDh6Wk3A==",
"content-type": "image/jpeg",
"etag": "\"0x8DD4B7827EE36A3\"",
"last-modified": "Wed, 12 Feb 2025 15:15:59 GMT"
}
}
}
}
2 changes: 1 addition & 1 deletion packages/inference/test/vcr.ts
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ async function vcr(
const tape: Tape = {
url,
init: {
headers: init.headers && omit(init.headers as Record<string, string>, ["Authorization", "User-Agent"]),
headers: init.headers && omit(init.headers as Record<string, string>, ["Authorization", "User-Agent", "X-Key"]),
method: init.method,
body: typeof init.body === "string" && init.body.length < 1_000 ? init.body : undefined,
},
Expand Down