diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 56c15ebb51..b488e42d3c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -48,6 +48,7 @@ jobs: HF_TOGETHER_KEY: dummy HF_NOVITA_KEY: dummy HF_FIREWORKS_KEY: dummy + HF_BLACK_FOREST_LABS_KEY: dummy browser: runs-on: ubuntu-latest @@ -91,6 +92,7 @@ jobs: HF_TOGETHER_KEY: dummy HF_NOVITA_KEY: dummy HF_FIREWORKS_KEY: dummy + HF_BLACK_FOREST_LABS_KEY: dummy e2e: runs-on: ubuntu-latest @@ -161,3 +163,4 @@ jobs: HF_TOGETHER_KEY: dummy HF_NOVITA_KEY: dummy HF_FIREWORKS_KEY: dummy + HF_BLACK_FOREST_LABS_KEY: dummy diff --git a/packages/inference/README.md b/packages/inference/README.md index b5168f2bc0..d537f81254 100644 --- a/packages/inference/README.md +++ b/packages/inference/README.md @@ -54,6 +54,7 @@ Currently, we support the following providers: - [Replicate](https://replicate.com) - [Sambanova](https://sambanova.ai) - [Together](https://together.xyz) +- [Black Forest Labs](https://blackforestlabs.ai) To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token. 
diff --git a/packages/inference/src/lib/makeRequestOptions.ts b/packages/inference/src/lib/makeRequestOptions.ts index 3bf20a7d15..bbc2368eca 100644 --- a/packages/inference/src/lib/makeRequestOptions.ts +++ b/packages/inference/src/lib/makeRequestOptions.ts @@ -6,6 +6,7 @@ import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova"; import { TOGETHER_API_BASE_URL } from "../providers/together"; import { NOVITA_API_BASE_URL } from "../providers/novita"; import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai"; +import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/black-forest-labs"; import type { InferenceProvider } from "../types"; import type { InferenceTask, Options, RequestArgs } from "../types"; import { isUrl } from "./isUrl"; @@ -80,8 +81,13 @@ export async function makeRequestOptions( const headers: Record<string, string> = {}; if (accessToken) { - headers["Authorization"] = - provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`; + if (provider === "fal-ai" && authMethod === "provider-key") { + headers["Authorization"] = `Key ${accessToken}`; + } else if (provider === "black-forest-labs" && authMethod === "provider-key") { + headers["X-Key"] = accessToken; + } else { + headers["Authorization"] = `Bearer ${accessToken}`; + } } // e.g. @huggingface/inference/3.1.3 @@ -148,6 +154,12 @@ function makeUrl(params: { const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key"; switch (params.provider) { + case "black-forest-labs": { + const baseUrl = shouldProxy + ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) + : BLACKFORESTLABS_AI_API_BASE_URL; + return `${baseUrl}/${params.model}`; + } case "fal-ai": { const baseUrl = shouldProxy ? 
HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) diff --git a/packages/inference/src/providers/black-forest-labs.ts b/packages/inference/src/providers/black-forest-labs.ts new file mode 100644 index 0000000000..fe5f140149 --- /dev/null +++ b/packages/inference/src/providers/black-forest-labs.ts @@ -0,0 +1,18 @@ +export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1"; + +/** + * See the registered mapping of HF model ID => Black Forest Labs model ID here: + * + * https://huggingface.co/api/partners/blackforestlabs/models + * + * This is a publicly available mapping. + * + * If you want to try to run inference for a new model locally before it's registered on huggingface.co, + * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes. + * + * - If you work at Black Forest Labs and want to update this mapping, please use the model mapping API we provide on huggingface.co + * - If you're a community member and want to add a new supported HF model to Black Forest Labs, please open an issue on the present repo + * and we will tag Black Forest Labs team members. + * + * Thanks! 
+ */ diff --git a/packages/inference/src/providers/consts.ts b/packages/inference/src/providers/consts.ts index 1c8bb27eb8..2ad19bb0fd 100644 --- a/packages/inference/src/providers/consts.ts +++ b/packages/inference/src/providers/consts.ts @@ -16,6 +16,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record(payload, { + const res = await request< + TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration | BlackForestLabsResponse + >(payload, { ...options, taskHint: "text-to-image", }); if (res && typeof res === "object") { + if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") { + return await pollBflResponse(res.polling_url); + } if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) { const image = await fetch(res.images[0].url); return await image.blob(); @@ -72,3 +82,33 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro } return res; } + +async function pollBflResponse(url: string): Promise { + const urlObj = new URL(url); + for (let step = 0; step < 5; step++) { + await delay(1000); + console.debug(`Polling Black Forest Labs API for the result... 
${step + 1}/5`); + urlObj.searchParams.set("attempt", step.toString(10)); + const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } }); + if (!resp.ok) { + throw new InferenceOutputError("Failed to fetch result from black forest labs API"); + } + const payload = await resp.json(); + if ( + typeof payload === "object" && + payload && + "status" in payload && + typeof payload.status === "string" && + payload.status === "Ready" && + "result" in payload && + typeof payload.result === "object" && + payload.result && + "sample" in payload.result && + typeof payload.result.sample === "string" + ) { + const image = await fetch(payload.result.sample); + return await image.blob(); + } + } + throw new InferenceOutputError("Failed to fetch result from black forest labs API"); +} diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts index 34cbbddf9a..d6eed1fc77 100644 --- a/packages/inference/src/types.ts +++ b/packages/inference/src/types.ts @@ -29,14 +29,15 @@ export interface Options { export type InferenceTask = Exclude<PipelineType, "other">; export const INFERENCE_PROVIDERS = [ + "black-forest-labs", "fal-ai", "fireworks-ai", - "nebius", "hf-inference", + "nebius", + "novita", "replicate", "sambanova", "together", - "novita", ] as const; export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number]; diff --git a/packages/inference/src/utils/delay.ts b/packages/inference/src/utils/delay.ts new file mode 100644 index 0000000000..e35acce902 --- /dev/null +++ b/packages/inference/src/utils/delay.ts @@ -0,0 +1,5 @@ +export function delay(ms: number): Promise<void> { + return new Promise((resolve) => { + setTimeout(() => resolve(), ms); + }); +} diff --git a/packages/inference/test/HfInference.spec.ts b/packages/inference/test/HfInference.spec.ts index 991d5463c8..0fe8931b2a 100644 --- a/packages/inference/test/HfInference.spec.ts +++ b/packages/inference/test/HfInference.spec.ts @@ -2,7 +2,7 @@ import { assert, describe, expect, it } from
"vitest"; import type { ChatCompletionStreamOutput } from "@huggingface/tasks"; -import { chatCompletion, HfInference } from "../src"; +import { chatCompletion, HfInference, textToImage } from "../src"; import { textToVideo } from "../src/tasks/cv/textToVideo"; import { readTestFile } from "./test-files"; import "./vcr"; @@ -1214,4 +1214,30 @@ describe.concurrent("HfInference", () => { }, TIMEOUT ); + describe.concurrent( + "Black Forest Labs", + () => { + HARDCODED_MODEL_ID_MAPPING["black-forest-labs"] = { + "black-forest-labs/FLUX.1-dev": "flux-dev", + // "black-forest-labs/FLUX.1-schnell": "flux-pro", + }; + + it("textToImage", async () => { + const res = await textToImage({ + model: "black-forest-labs/FLUX.1-dev", + provider: "black-forest-labs", + accessToken: env.HF_BLACK_FOREST_LABS_KEY, + inputs: "A raccoon driving a truck", + parameters: { + height: 256, + width: 256, + num_inference_steps: 4, + seed: 8817, + }, + }); + expect(res).toBeInstanceOf(Blob); + }); + }, + TIMEOUT + ); }); diff --git a/packages/inference/test/tapes.json b/packages/inference/test/tapes.json index 79593a964d..f150893417 100644 --- a/packages/inference/test/tapes.json +++ b/packages/inference/test/tapes.json @@ -6994,5 +6994,80 @@ "transfer-encoding": "chunked" } } + }, + "b320223c78e20541a47c961d89d24f507b0b0257224d91cd05744c93f2d67d2c": { + "url": "https://api.us1.bfl.ai/v1/flux-dev", + "init": { + "headers": { + "Content-Type": "application/json" + }, + "method": "POST", + "body": "{\"height\":256,\"width\":256,\"num_inference_steps\":4,\"seed\":8817,\"prompt\":\"A raccoon driving a truck\"}" + }, + "response": { + "body": "{\"id\":\"dd8b1c92-587b-428c-9095-d8c12a641160\",\"polling_url\":\"https://api.us1.bfl.ai/v1/get_result?id=dd8b1c92-587b-428c-9095-d8c12a641160\"}", + "status": 200, + "statusText": "OK", + "headers": { + "connection": "keep-alive", + "content-type": "application/json", + "strict-transport-security": "max-age=31536000; includeSubDomains" + } + } + }, + 
"23eefbade142f7a1e33d50dd6bfaf56e7b959689f7990025db1b353214890a03": { + "url": "https://api.us1.bfl.ai/v1/get_result?id=dd8b1c92-587b-428c-9095-d8c12a641160&attempt=0", + "init": { + "headers": { + "Content-Type": "application/json" + } + }, + "response": { + "body": "{\"id\":\"dd8b1c92-587b-428c-9095-d8c12a641160\",\"status\":\"Pending\",\"result\":null,\"progress\":0.6}", + "status": 200, + "statusText": "OK", + "headers": { + "connection": "keep-alive", + "content-type": "application/json", + "retry-after": "1", + "strict-transport-security": "max-age=31536000; includeSubDomains" + } + } + }, + "5803254b4092ae6ac445292c617480002607bb30cce9ba8dc37ce9bb2754f94b": { + "url": "https://api.us1.bfl.ai/v1/get_result?id=dd8b1c92-587b-428c-9095-d8c12a641160&attempt=1", + "init": { + "headers": { + "Content-Type": "application/json" + } + }, + "response": { + "body": "{\"id\":\"dd8b1c92-587b-428c-9095-d8c12a641160\",\"status\":\"Ready\",\"result\":{\"sample\":\"https://delivery-us1.bfl.ai/results/aa7ab8da64b946dca070d455854a0c3e/sample.jpeg?se=2025-02-13T16%3A12%3A37Z&sp=r&sv=2024-11-04&sr=b&rsct=image/jpeg&sig=u0wzLXKBr8dCMnk9US51zQs7Ma/x/l0lEJvEM3pMUrA%3D\",\"prompt\":\"A raccoon driving a truck\",\"seed\":8817,\"start_time\":1739462555.336884,\"end_time\":1739462557.9051642,\"duration\":2.5682802200317383},\"progress\":null}", + "status": 200, + "statusText": "OK", + "headers": { + "connection": "keep-alive", + "content-type": "application/json", + "retry-after": "1", + "strict-transport-security": "max-age=31536000; includeSubDomains" + } + } + }, + "548ead8522302cb1123833c27e33d193a8fd619633271414bd2d84e1a71469f0": { + "url": "https://delivery-us1.bfl.ai/results/aa7ab8da64b946dca070d455854a0c3e/sample.jpeg?se=2025-02-13T16%3A12%3A37Z&sp=r&sv=2024-11-04&sr=b&rsct=image/jpeg&sig=u0wzLXKBr8dCMnk9US51zQs7Ma/x/l0lEJvEM3pMUrA%3D", + "init": {}, + "response": { + "body": "", + "status": 200, + "statusText": "OK", + "headers": { + "accept-ranges": "bytes", + "connection": 
"keep-alive", + "content-md5": "RRVPlYWCAb46mkS5lSs7RQ==", + "content-type": "image/jpeg", + "etag": "\"0x8DD4C47D65702E7\"", + "last-modified": "Thu, 13 Feb 2025 16:02:37 GMT" + } + } } } \ No newline at end of file diff --git a/packages/inference/test/vcr.ts b/packages/inference/test/vcr.ts index 2559f523b1..79a548af04 100644 --- a/packages/inference/test/vcr.ts +++ b/packages/inference/test/vcr.ts @@ -181,7 +181,7 @@ async function vcr( const tape: Tape = { url, init: { - headers: init.headers && omit(init.headers as Record, ["Authorization", "User-Agent"]), + headers: init.headers && omit(init.headers as Record, ["Authorization", "User-Agent", "X-Key"]), method: init.method, body: typeof init.body === "string" && init.body.length < 1_000 ? init.body : undefined, },