Skip to content

Commit 081a6ab

Browse files
authored
[Inference] text-to-image: option to return a URL instead of Blob (#1207)
Context (internal): https://huggingface.slack.com/archives/C07KX53FZTK/p1739562014823989?thread_ts=1739561786.364929&cid=C07KX53FZTK
1 parent a36e81f commit 081a6ab

File tree

2 files changed

+58
-7
lines changed

2 files changed

+58
-7
lines changed

packages/inference/src/tasks/cv/textToImage.ts

Lines changed: 37 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,10 @@ interface BlackForestLabsResponse {
2424
polling_url: string;
2525
}
2626

27+
interface TextToImageOptions extends Options {
28+
outputType?: "url" | "blob";
29+
}
30+
2731
function getResponseFormatArg(provider: InferenceProvider) {
2832
switch (provider) {
2933
case "fal-ai":
@@ -43,7 +47,15 @@ function getResponseFormatArg(provider: InferenceProvider) {
4347
* This task reads some text input and outputs an image.
4448
* Recommended model: stabilityai/stable-diffusion-2
4549
*/
46-
export async function textToImage(args: TextToImageArgs, options?: Options): Promise<Blob> {
50+
export async function textToImage(
51+
args: TextToImageArgs,
52+
options?: TextToImageOptions & { outputType: "url" }
53+
): Promise<string>;
54+
export async function textToImage(
55+
args: TextToImageArgs,
56+
options?: TextToImageOptions & { outputType?: undefined | "blob" }
57+
): Promise<Blob>;
58+
export async function textToImage(args: TextToImageArgs, options?: TextToImageOptions): Promise<Blob | string> {
4759
const payload =
4860
!args.provider || args.provider === "hf-inference" || args.provider === "sambanova"
4961
? args
@@ -66,11 +78,18 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
6678

6779
if (res && typeof res === "object") {
6880
if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
81+
if (options?.outputType === "url") {
82+
return res.polling_url;
83+
}
6984
return await pollBflResponse(res.polling_url);
7085
}
7186
if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
72-
const image = await fetch(res.images[0].url);
73-
return await image.blob();
87+
if (options?.outputType === "url") {
88+
return res.images[0].url;
89+
} else {
90+
const image = await fetch(res.images[0].url);
91+
return await image.blob();
92+
}
7493
}
7594
if (
7695
args.provider === "hyperbolic" &&
@@ -79,17 +98,24 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
7998
res.images[0] &&
8099
typeof res.images[0].image === "string"
81100
) {
101+
if (options?.outputType === "url") {
102+
return `data:image/jpeg;base64,${res.images[0].image}`;
103+
}
82104
const base64Response = await fetch(`data:image/jpeg;base64,${res.images[0].image}`);
83-
const blob = await base64Response.blob();
84-
return blob;
105+
return await base64Response.blob();
85106
}
86107
if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
87108
const base64Data = res.data[0].b64_json;
109+
if (options?.outputType === "url") {
110+
return `data:image/jpeg;base64,${base64Data}`;
111+
}
88112
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
89-
const blob = await base64Response.blob();
90-
return blob;
113+
return await base64Response.blob();
91114
}
92115
if ("output" in res && Array.isArray(res.output)) {
116+
if (options?.outputType === "url") {
117+
return res.output[0];
118+
}
93119
const urlResponse = await fetch(res.output[0]);
94120
const blob = await urlResponse.blob();
95121
return blob;
@@ -99,6 +125,10 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
99125
if (!isValidOutput) {
100126
throw new InferenceOutputError("Expected Blob");
101127
}
128+
if (options?.outputType === "url") {
129+
const b64 = await res.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
130+
return `data:image/jpeg;base64,${b64}`;
131+
}
102132
return res;
103133
}
104134

packages/inference/test/HfInference.spec.ts

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@ import { textToVideo } from "../src/tasks/cv/textToVideo";
88
import { readTestFile } from "./test-files";
99
import "./vcr";
1010
import { HARDCODED_MODEL_ID_MAPPING } from "../src/providers/consts";
11+
import { isUrl } from "../src/lib/isUrl";
1112

1213
const TIMEOUT = 60000 * 3;
1314
const env = import.meta.env;
@@ -1326,6 +1327,26 @@ describe.concurrent("HfInference", () => {
13261327
});
13271328
expect(res).toBeInstanceOf(Blob);
13281329
});
1330+
1331+
it("textToImage URL", async () => {
1332+
const res = await textToImage(
1333+
{
1334+
model: "black-forest-labs/FLUX.1-dev",
1335+
provider: "black-forest-labs",
1336+
accessToken: env.HF_BLACK_FOREST_LABS_KEY,
1337+
inputs: "A raccoon driving a truck",
1338+
parameters: {
1339+
height: 256,
1340+
width: 256,
1341+
num_inference_steps: 4,
1342+
seed: 8817,
1343+
},
1344+
},
1345+
{ outputType: "url" }
1346+
);
1347+
expect(res).toBeTypeOf("string");
1348+
expect(isUrl(res)).toBeTruthy();
1349+
});
13291350
},
13301351
TIMEOUT
13311352
);

0 commit comments

Comments
 (0)