Skip to content

Commit a3f6ecd

Browse files
authored
[Tasks] add text-to-video JSON spec + Inference client (#1099)
# TL;DR - Add input / output JSON schema specs for the `text-to-video` task - Use generated TypeScript interfaces to add support for this in the inference client - Add a few text-to-video models to the list of supported models for Fal.ai and Replicate
1 parent 1b47f90 commit a3f6ecd

File tree

11 files changed

+766
-103
lines changed

11 files changed

+766
-103
lines changed

packages/inference/src/providers/fal-ai.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,8 @@ export const FAL_AI_SUPPORTED_MODEL_IDS: ProviderMapping<FalAiId> = {
2020
"automatic-speech-recognition": {
2121
"openai/whisper-large-v3": "fal-ai/whisper",
2222
},
23+
"text-to-video": {
24+
"genmo/mochi-1-preview": "fal-ai/mochi-v1",
25+
"tencent/HunyuanVideo": "fal-ai/hunyuan-video",
26+
},
2327
};

packages/inference/src/providers/replicate.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,7 @@ export const REPLICATE_SUPPORTED_MODEL_IDS: ProviderMapping<ReplicateId> = {
1313
"text-to-speech": {
1414
"OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:39a59319327b27327fa3095149c5a746e7f2aee18c75055c3368237a6503cd26",
1515
},
16+
"text-to-video": {
17+
"genmo/mochi-1-preview": "genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460",
18+
},
1619
};

packages/inference/src/tasks/cv/textToImage.ts

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { InferenceOutputError } from "../../lib/InferenceOutputError";
22
import type { BaseArgs, Options } from "../../types";
3+
import { omit } from "../../utils/omit";
34
import { request } from "../custom/request";
45

56
export type TextToImageArgs = BaseArgs & {
@@ -57,15 +58,16 @@ interface OutputUrlImageGeneration {
5758
* Recommended model: stabilityai/stable-diffusion-2
5859
*/
5960
export async function textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageOutput> {
60-
if (args.provider === "together" || args.provider === "fal-ai") {
61-
args.prompt = args.inputs;
62-
delete (args as unknown as { inputs: unknown }).inputs;
63-
args.response_format = "base64";
64-
} else if (args.provider === "replicate") {
65-
args.prompt = args.inputs;
66-
delete (args as unknown as { inputs: unknown }).inputs;
67-
}
68-
const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(args, {
61+
const payload =
62+
args.provider === "together" || args.provider === "fal-ai" || args.provider === "replicate"
63+
? {
64+
...omit(args, ["inputs", "parameters"]),
65+
...args.parameters,
66+
...(args.provider !== "replicate" ? { response_format: "base64" } : undefined),
67+
prompt: args.inputs,
68+
}
69+
: args;
70+
const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(payload, {
6971
...options,
7072
taskHint: "text-to-image",
7173
});
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import type { BaseArgs, InferenceProvider, Options } from "../../types";
2+
import type { TextToVideoInput } from "@huggingface/tasks";
3+
import { request } from "../custom/request";
4+
import { omit } from "../../utils/omit";
5+
import { isUrl } from "../../lib/isUrl";
6+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
7+
import { typedInclude } from "../../utils/typedInclude";
8+
9+
export type TextToVideoArgs = BaseArgs & TextToVideoInput;
10+
11+
export type TextToVideoOutput = Blob;
12+
13+
interface FalAiOutput {
14+
video: {
15+
url: string;
16+
};
17+
}
18+
19+
interface ReplicateOutput {
20+
output: string;
21+
}
22+
23+
const SUPPORTED_PROVIDERS = ["fal-ai", "replicate"] as const satisfies readonly InferenceProvider[];
24+
25+
export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
26+
if (!args.provider || !typedInclude(SUPPORTED_PROVIDERS, args.provider)) {
27+
throw new Error(
28+
`textToVideo inference is only supported for the following providers: ${SUPPORTED_PROVIDERS.join(", ")}`
29+
);
30+
}
31+
32+
const payload =
33+
args.provider === "fal-ai" || args.provider === "replicate"
34+
? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs }
35+
: args;
36+
const res = await request<FalAiOutput | ReplicateOutput>(payload, {
37+
...options,
38+
taskHint: "text-to-video",
39+
});
40+
41+
if (args.provider === "fal-ai") {
42+
const isValidOutput =
43+
typeof res === "object" &&
44+
!!res &&
45+
"video" in res &&
46+
typeof res.video === "object" &&
47+
!!res.video &&
48+
"url" in res.video &&
49+
typeof res.video.url === "string" &&
50+
isUrl(res.video.url);
51+
if (!isValidOutput) {
52+
throw new InferenceOutputError("Expected { video: { url: string } }");
53+
}
54+
const urlResponse = await fetch(res.video.url);
55+
return await urlResponse.blob();
56+
} else {
57+
/// TODO: Replicate: handle the case where the generation request "times out" / is async (ie output is null)
58+
/// https://replicate.com/docs/topics/predictions/create-a-prediction
59+
const isValidOutput =
60+
typeof res === "object" && !!res && "output" in res && typeof res.output === "string" && isUrl(res.output);
61+
if (!isValidOutput) {
62+
throw new InferenceOutputError("Expected { output: string }");
63+
}
64+
const urlResponse = await fetch(res.output);
65+
return await urlResponse.blob();
66+
}
67+
}

packages/inference/src/types.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ export interface BaseArgs {
8484
}
8585

8686
export type RequestArgs = BaseArgs &
87-
({ data: Blob | ArrayBuffer } | { inputs: unknown } | ChatCompletionInput) & {
87+
({ data: Blob | ArrayBuffer } | { inputs: unknown } | { prompt: string } | ChatCompletionInput) & {
8888
parameters?: Record<string, unknown>;
8989
accessToken?: string;
9090
};

packages/inference/test/HfInference.spec.ts

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import type { ChatCompletionStreamOutput } from "@huggingface/tasks";
55
import { chatCompletion, FAL_AI_SUPPORTED_MODEL_IDS, HfInference } from "../src";
66
import "./vcr";
77
import { readTestFile } from "./test-files";
8+
import { textToVideo } from "../src/tasks/cv/textToVideo";
89

910
const TIMEOUT = 60000 * 3;
1011
const env = import.meta.env;
@@ -47,7 +48,7 @@ describe.concurrent("HfInference", () => {
4748
);
4849
});
4950

50-
it("works without model", async () => {
51+
it.skip("works without model", async () => {
5152
expect(
5253
await hf.fillMask({
5354
inputs: "[MASK] world!",
@@ -799,6 +800,35 @@ describe.concurrent("HfInference", () => {
799800
});
800801
});
801802
}
803+
804+
it("textToVideo - genmo/mochi-1-preview", async () => {
805+
const res = await textToVideo({
806+
model: "genmo/mochi-1-preview",
807+
inputs: "A running dog",
808+
parameters: {
809+
seed: 176,
810+
},
811+
provider: "fal-ai",
812+
accessToken: env.HF_FAL_KEY,
813+
});
814+
expect(res).toBeInstanceOf(Blob);
815+
});
816+
817+
it("textToVideo - HunyuanVideo", async () => {
818+
const res = await textToVideo({
819+
model: "genmo/mochi-1-preview",
820+
inputs: "A running dog",
821+
parameters: {
822+
seed: 176,
823+
num_inference_steps: 2,
824+
num_frames: 85,
825+
resolution: "480p",
826+
},
827+
provider: "fal-ai",
828+
accessToken: env.HF_FAL_KEY,
829+
});
830+
expect(res).toBeInstanceOf(Blob);
831+
});
802832
},
803833
TIMEOUT
804834
);
@@ -844,6 +874,22 @@ describe.concurrent("HfInference", () => {
844874

845875
expect(res).toBeInstanceOf(Blob);
846876
});
877+
878+
it("textToVideo Mochi", async () => {
879+
const res = await textToVideo({
880+
accessToken: env.HF_REPLICATE_KEY,
881+
model: "genmo/mochi-1-preview",
882+
provider: "replicate",
883+
inputs: "A running dog",
884+
parameters: {
885+
num_inference_steps: 10,
886+
seed: 178,
887+
num_frames: 30,
888+
},
889+
});
890+
891+
expect(res).toBeInstanceOf(Blob);
892+
});
847893
},
848894
TIMEOUT
849895
);

0 commit comments

Comments (0)