Skip to content

Commit c96c17b

Browse files
authored
[Inference] get the raw response for textToImage (#1617)
Allow to get the full response for `text-to-image` inference, instead of directly the blob or the url: ```js const res = await textToImage({ inputs: "A dog with a baseball cap" },{ outputType: "json" }); ``` we have the following ouput (depends on the provider) ```json { "images": [ { "url": "data:image/png;base64,iVBORw0KGgoAAAAN......rkJggg==", "width": 1024, "height": 768, "content_type": "image/png" } ], "timings": { "inference": 1.2445615287870169 }, "seed": 982874191, "has_nsfw_concepts": [ false ], "prompt": "A dog with a baseball cap" } ``` needed for (internal huggingface-internal/moon-landing#14421)
1 parent e824d26 commit c96c17b

File tree

11 files changed

+71
-20
lines changed

11 files changed

+71
-20
lines changed

packages/inference/src/providers/black-forest-labs.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ export class BlackForestLabsTextToImageTask extends TaskProviderHelper implement
6666
response: BlackForestLabsResponse,
6767
url?: string,
6868
headers?: HeadersInit,
69-
outputType?: "url" | "blob"
70-
): Promise<string | Blob> {
69+
outputType?: "url" | "blob" | "json"
70+
): Promise<string | Blob | Record<string, unknown>> {
7171
const logger = getLogger();
7272
const urlObj = new URL(response.polling_url);
7373
for (let step = 0; step < 5; step++) {
@@ -95,6 +95,9 @@ export class BlackForestLabsTextToImageTask extends TaskProviderHelper implement
9595
"sample" in payload.result &&
9696
typeof payload.result.sample === "string"
9797
) {
98+
if (outputType === "json") {
99+
return payload.result;
100+
}
98101
if (outputType === "url") {
99102
return payload.result.sample;
100103
}

packages/inference/src/providers/fal-ai.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,12 @@ export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHe
182182
return payload;
183183
}
184184

185-
override async getResponse(response: FalAITextToImageOutput, outputType?: "url" | "blob"): Promise<string | Blob> {
185+
override async getResponse(
186+
response: FalAITextToImageOutput,
187+
url?: string,
188+
headers?: HeadersInit,
189+
outputType?: "url" | "blob" | "json"
190+
): Promise<string | Blob | Record<string, unknown>> {
186191
if (
187192
typeof response === "object" &&
188193
"images" in response &&
@@ -191,6 +196,9 @@ export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHe
191196
"url" in response.images[0] &&
192197
typeof response.images[0].url === "string"
193198
) {
199+
if (outputType === "json") {
200+
return { ...response };
201+
}
194202
if (outputType === "url") {
195203
return response.images[0].url;
196204
}

packages/inference/src/providers/hf-inference.ts

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,17 @@ export class HFInferenceTextToImageTask extends HFInferenceTask implements TextT
127127
response: Base64ImageGeneration | OutputUrlImageGeneration,
128128
url?: string,
129129
headers?: HeadersInit,
130-
outputType?: "url" | "blob"
131-
): Promise<string | Blob> {
130+
outputType?: "url" | "blob" | "json"
131+
): Promise<string | Blob | Record<string, unknown>> {
132132
if (!response) {
133133
throw new InferenceClientProviderOutputError(
134134
"Received malformed response from HF-Inference text-to-image API: response is undefined"
135135
);
136136
}
137137
if (typeof response == "object") {
138+
if (outputType === "json") {
139+
return { ...response };
140+
}
138141
if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
139142
const base64Data = response.data[0].b64_json;
140143
if (outputType === "url") {
@@ -153,9 +156,9 @@ export class HFInferenceTextToImageTask extends HFInferenceTask implements TextT
153156
}
154157
}
155158
if (response instanceof Blob) {
156-
if (outputType === "url") {
159+
if (outputType === "url" || outputType === "json") {
157160
const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
158-
return `data:image/jpeg;base64,${b64}`;
161+
return outputType === "url" ? `data:image/jpeg;base64,${b64}` : { output: `data:image/jpeg;base64,${b64}` };
159162
}
160163
return response;
161164
}

packages/inference/src/providers/hyperbolic.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,18 @@ export class HyperbolicTextToImageTask extends TaskProviderHelper implements Tex
105105
response: HyperbolicTextToImageOutput,
106106
url?: string,
107107
headers?: HeadersInit,
108-
outputType?: "url" | "blob"
109-
): Promise<string | Blob> {
108+
outputType?: "url" | "blob" | "json"
109+
): Promise<string | Blob | Record<string, unknown>> {
110110
if (
111111
typeof response === "object" &&
112112
"images" in response &&
113113
Array.isArray(response.images) &&
114114
response.images[0] &&
115115
typeof response.images[0].image === "string"
116116
) {
117+
if (outputType === "json") {
118+
return { ...response };
119+
}
117120
if (outputType === "url") {
118121
return `data:image/jpeg;base64,${response.images[0].image}`;
119122
}

packages/inference/src/providers/nebius.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ export class NebiusTextToImageTask extends TaskProviderHelper implements TextToI
116116
response: NebiusBase64ImageGeneration,
117117
url?: string,
118118
headers?: HeadersInit,
119-
outputType?: "url" | "blob"
120-
): Promise<string | Blob> {
119+
outputType?: "url" | "blob" | "json"
120+
): Promise<string | Blob | Record<string, unknown>> {
121121
if (
122122
typeof response === "object" &&
123123
"data" in response &&
@@ -126,6 +126,9 @@ export class NebiusTextToImageTask extends TaskProviderHelper implements TextToI
126126
"b64_json" in response.data[0] &&
127127
typeof response.data[0].b64_json === "string"
128128
) {
129+
if (outputType === "json") {
130+
return { ...response };
131+
}
129132
const base64Data = response.data[0].b64_json;
130133
if (outputType === "url") {
131134
return `data:image/jpeg;base64,${base64Data}`;

packages/inference/src/providers/nscale.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ export class NscaleTextToImageTask extends TaskProviderHelper implements TextToI
5757
response: NscaleCloudBase64ImageGeneration,
5858
url?: string,
5959
headers?: HeadersInit,
60-
outputType?: "url" | "blob"
61-
): Promise<string | Blob> {
60+
outputType?: "url" | "blob" | "json"
61+
): Promise<string | Blob | Record<string, unknown>> {
6262
if (
6363
typeof response === "object" &&
6464
"data" in response &&
@@ -67,6 +67,9 @@ export class NscaleTextToImageTask extends TaskProviderHelper implements TextToI
6767
"b64_json" in response.data[0] &&
6868
typeof response.data[0].b64_json === "string"
6969
) {
70+
if (outputType === "json") {
71+
return { ...response };
72+
}
7073
const base64Data = response.data[0].b64_json;
7174
if (outputType === "url") {
7275
return `data:image/jpeg;base64,${base64Data}`;

packages/inference/src/providers/providerHelper.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,8 @@ export interface TextToImageTaskHelper {
137137
response: unknown,
138138
url?: string,
139139
headers?: HeadersInit,
140-
outputType?: "url" | "blob"
141-
): Promise<string | Blob>;
140+
outputType?: "url" | "blob" | "json"
141+
): Promise<string | Blob | Record<string, unknown>>;
142142
preparePayload(params: BodyParams<TextToImageInput & BaseArgs>): Record<string, unknown>;
143143
}
144144

packages/inference/src/providers/replicate.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ export class ReplicateTextToImageTask extends ReplicateTask implements TextToIma
8888
res: ReplicateOutput | Blob,
8989
url?: string,
9090
headers?: Record<string, string>,
91-
outputType?: "url" | "blob"
92-
): Promise<string | Blob> {
91+
outputType?: "url" | "blob" | "json"
92+
): Promise<string | Blob | Record<string, unknown>> {
9393
void url;
9494
void headers;
9595
if (
@@ -99,6 +99,9 @@ export class ReplicateTextToImageTask extends ReplicateTask implements TextToIma
9999
res.output.length > 0 &&
100100
typeof res.output[0] === "string"
101101
) {
102+
if (outputType === "json") {
103+
return { ...res };
104+
}
102105
if (outputType === "url") {
103106
return res.output[0];
104107
}

packages/inference/src/providers/together.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,12 @@ export class TogetherTextToImageTask extends TaskProviderHelper implements TextT
114114
};
115115
}
116116

117-
async getResponse(response: TogetherBase64ImageGeneration, outputType?: "url" | "blob"): Promise<string | Blob> {
117+
async getResponse(
118+
response: TogetherBase64ImageGeneration,
119+
url?: string,
120+
headers?: HeadersInit,
121+
outputType?: "url" | "blob" | "json"
122+
): Promise<string | Blob | Record<string, unknown>> {
118123
if (
119124
typeof response === "object" &&
120125
"data" in response &&
@@ -123,6 +128,9 @@ export class TogetherTextToImageTask extends TaskProviderHelper implements TextT
123128
"b64_json" in response.data[0] &&
124129
typeof response.data[0].b64_json === "string"
125130
) {
131+
if (outputType === "json") {
132+
return { ...response };
133+
}
126134
const base64Data = response.data[0].b64_json;
127135
if (outputType === "url") {
128136
return `data:image/jpeg;base64,${base64Data}`;

packages/inference/src/tasks/cv/textToImage.ts

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import { innerRequest } from "../../utils/request.js";
88
export type TextToImageArgs = BaseArgs & TextToImageInput;
99

1010
interface TextToImageOptions extends Options {
11-
outputType?: "url" | "blob";
11+
outputType?: "url" | "blob" | "json";
1212
}
1313

1414
/**
@@ -23,7 +23,14 @@ export async function textToImage(
2323
args: TextToImageArgs,
2424
options?: TextToImageOptions & { outputType?: undefined | "blob" }
2525
): Promise<Blob>;
26-
export async function textToImage(args: TextToImageArgs, options?: TextToImageOptions): Promise<Blob | string> {
26+
export async function textToImage(
27+
args: TextToImageArgs,
28+
options?: TextToImageOptions & { outputType?: undefined | "json" }
29+
): Promise<Record<string, unknown>>;
30+
export async function textToImage(
31+
args: TextToImageArgs,
32+
options?: TextToImageOptions
33+
): Promise<Blob | string | Record<string, unknown>> {
2734
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2835
const providerHelper = getProviderHelper(provider, "text-to-image");
2936
const { data: res } = await innerRequest<Record<string, unknown>>(args, providerHelper, {

0 commit comments

Comments
 (0)