Skip to content

Commit d96a18e

Browse files
committed
textToImage should work too
1 parent 5629b86 commit d96a18e

File tree

3 files changed

+38
-1
lines changed

3 files changed

+38
-1
lines changed

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ export async function makeRequestOptions(
118118
case "sambanova":
119119
return SAMBANOVA_API_BASE_URL;
120120
case "together":
121+
if (taskHint === "text-to-image") {
122+
return `${TOGETHER_API_BASE_URL}/v1/images/generations`;
123+
}
121124
return TOGETHER_API_BASE_URL;
122125
default:
123126
break;

packages/inference/src/tasks/cv/textToImage.ts

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ export type TextToImageArgs = BaseArgs & {
88
*/
99
inputs: string;
1010

11+
/**
12+
* Same param but for external providers like Together
13+
*/
14+
prompt?: string;
15+
response_format?: "base64";
16+
1117
parameters?: {
1218
/**
1319
* An optional negative prompt for the image generation
@@ -32,17 +38,35 @@ export type TextToImageArgs = BaseArgs & {
3238
};
3339
};
3440

41+
interface Base64ImageGeneration {
42+
id: string;
43+
model: string;
44+
data: Array<{
45+
b64_json: string;
46+
}>;
47+
}
3548
export type TextToImageOutput = Blob;
3649

3750
/**
3851
* This task reads some text input and outputs an image.
3952
* Recommended model: stabilityai/stable-diffusion-2
4053
*/
4154
export async function textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageOutput> {
42-
const res = await request<TextToImageOutput>(args, {
55+
if (args.provider === "together") {
56+
args.prompt = args.inputs;
57+
args.inputs = "";
58+
args.response_format = "base64";
59+
}
60+
const res = await request<TextToImageOutput | Base64ImageGeneration>(args, {
4361
...options,
4462
taskHint: "text-to-image",
4563
});
64+
if (res && typeof res === "object" && Array.isArray(res.data) && res.data[0].b64_json) {
65+
const base64Data = res.data[0].b64_json;
66+
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
67+
const blob = await base64Response.blob();
68+
return blob;
69+
}
4670
const isValidOutput = res && res instanceof Blob;
4771
if (!isValidOutput) {
4872
throw new InferenceOutputError("Expected Blob");

packages/inference/test/HfInference.spec.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,16 @@ describe.concurrent(
816816
}
817817
expect(out).toContain("2");
818818
});
819+
820+
it("textToImage together", async () => {
821+
const hf = new HfInference(env.TOGETHER_KEY);
822+
const res = await hf.textToImage({
823+
model: "stabilityai/stable-diffusion-xl-base-1.0",
824+
provider: "together",
825+
inputs: "award winning high resolution photo of a giant tortoise",
826+
});
827+
expect(res).toBeInstanceOf(Blob);
828+
});
819829
},
820830
TIMEOUT
821831
);

0 commit comments

Comments
 (0)