Skip to content

Commit fa7ca44

Browse files
committed
feat. Refactor the NovitaTextToVideoTask using the async API.
1 parent 29150f4 commit fa7ca44

File tree

4 files changed

+75
-22
lines changed

4 files changed

+75
-22
lines changed

packages/inference/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ Currently, we support the following providers:
5252
- [Fireworks AI](https://fireworks.ai)
5353
- [Hyperbolic](https://hyperbolic.xyz)
5454
- [Nebius](https://studio.nebius.ai)
55-
- [Novita](https://novita.ai/?utm_source=github_huggingface&utm_medium=github_readme&utm_campaign=link)
55+
- [Novita](https://novita.ai)
5656
- [Nscale](https://nscale.com)
5757
- [OVHcloud](https://endpoints.ai.cloud.ovh.net/)
5858
- [Replicate](https://replicate.com)
@@ -93,6 +93,7 @@ Only a subset of models are supported when requesting third-party providers. You
9393
- [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
9494
- [Groq supported models](https://console.groq.com/docs/models)
9595
- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
96+
- [Novita AI supported models](https://huggingface.co/api/partners/novita/models)
9697

9798
**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
9899
This is not an issue for LLMs as everyone converged on the OpenAI API anyways, but can be more tricky for other tasks like "text-to-image" or "automatic-speech-recognition" where there exists no standard API. Let us know if any help is needed or if we can make things easier for you!

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
120120
novita: {
121121
conversational: new Novita.NovitaConversationalTask(),
122122
"text-generation": new Novita.NovitaTextGenerationTask(),
123+
"text-to-video": new Novita.NovitaTextToVideoTask(),
123124
},
124125
nscale: {
125126
"text-to-image": new Nscale.NscaleTextToImageTask(),

packages/inference/src/providers/novita.ts

Lines changed: 70 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import { InferenceOutputError } from "../lib/InferenceOutputError";
1818
import { isUrl } from "../lib/isUrl";
1919
import type { BodyParams, UrlParams } from "../types";
20+
import { delay } from "../utils/delay";
2021
import { omit } from "../utils/omit";
2122
import {
2223
BaseConversationalTask,
@@ -26,11 +27,11 @@ import {
2627
} from "./providerHelper";
2728

2829
const NOVITA_API_BASE_URL = "https://api.novita.ai";
29-
export interface NovitaOutput {
30-
video: {
31-
video_url: string;
32-
};
30+
31+
export interface NovitaAsyncAPIOutput {
32+
task_id: string;
3333
}
34+
3435
export class NovitaTextGenerationTask extends BaseTextGenerationTask {
3536
constructor() {
3637
super("novita", NOVITA_API_BASE_URL);
@@ -50,38 +51,88 @@ export class NovitaConversationalTask extends BaseConversationalTask {
5051
return "/v3/openai/chat/completions";
5152
}
5253
}
54+
5355
export class NovitaTextToVideoTask extends TaskProviderHelper implements TextToVideoTaskHelper {
5456
constructor() {
5557
super("novita", NOVITA_API_BASE_URL);
5658
}
5759

58-
makeRoute(params: UrlParams): string {
59-
return `/v3/hf/${params.model}`;
60+
override makeRoute(params: UrlParams): string {
61+
if (params.authMethod !== "provider-key") {
62+
return `/v3/async/${params.model}?_subdomain=queue`;
63+
}
64+
return `/v3/async/${params.model}`;
6065
}
6166

62-
preparePayload(params: BodyParams): Record<string, unknown> {
67+
override preparePayload(params: BodyParams): Record<string, unknown> {
68+
const { num_inference_steps, ...restParameters } = params.args.parameters as Record<string, unknown>;
6369
return {
6470
...omit(params.args, ["inputs", "parameters"]),
65-
...(params.args.parameters as Record<string, unknown>),
71+
...restParameters,
72+
steps: num_inference_steps,
6673
prompt: params.args.inputs,
6774
};
6875
}
69-
override async getResponse(response: NovitaOutput): Promise<Blob> {
76+
77+
override async getResponse(
78+
response: NovitaAsyncAPIOutput,
79+
url?: string,
80+
headers?: Record<string, string>
81+
): Promise<Blob> {
82+
if (!url || !headers) {
83+
throw new InferenceOutputError("URL and headers are required for text-to-video task");
84+
}
85+
const taskId = response.task_id;
86+
if (!taskId) {
87+
throw new InferenceOutputError("No task ID found in the response");
88+
}
89+
90+
const parsedUrl = new URL(url);
91+
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
92+
parsedUrl.host === "router.huggingface.co" ? "/novita" : ""
93+
}`;
94+
const queryParams = parsedUrl.search;
95+
const resultUrl = `${baseUrl}/v3/async/task-result${queryParams ? queryParams + '&' : '?'}task_id=${taskId}`;
96+
97+
let status = '';
98+
let taskResult = undefined;
99+
100+
while (status !== 'TASK_STATUS_SUCCEED' && status !== 'TASK_STATUS_FAILED') {
101+
await delay(500);
102+
const resultResponse = await fetch(resultUrl, { headers });
103+
if (!resultResponse.ok) {
104+
throw new InferenceOutputError("Failed to fetch task result");
105+
}
106+
try {
107+
taskResult = await resultResponse.json();
108+
status = taskResult.task.status;
109+
} catch (error) {
110+
throw new InferenceOutputError("Failed to parse task result");
111+
}
112+
}
113+
114+
if (status === 'TASK_STATUS_FAILED') {
115+
throw new InferenceOutputError("Task failed");
116+
}
117+
118+
// There will be at most one video in the response.
70119
const isValidOutput =
71-
typeof response === "object" &&
72-
!!response &&
73-
"video" in response &&
74-
typeof response.video === "object" &&
75-
!!response.video &&
76-
"video_url" in response.video &&
77-
typeof response.video.video_url === "string" &&
78-
isUrl(response.video.video_url);
120+
typeof taskResult === "object" &&
121+
!!taskResult &&
122+
"videos" in taskResult &&
123+
typeof taskResult.videos === "object" &&
124+
!!taskResult.videos &&
125+
Array.isArray(taskResult.videos) &&
126+
taskResult.videos.length > 0 &&
127+
"video_url" in taskResult.videos[0] &&
128+
typeof taskResult.videos[0].video_url === "string" &&
129+
isUrl(taskResult.videos[0].video_url);
79130

80131
if (!isValidOutput) {
81-
throw new InferenceOutputError("Expected { video: { video_url: string } }");
132+
throw new InferenceOutputError("Expected { videos: [{ video_url: string }] }");
82133
}
83134

84-
const urlResponse = await fetch(response.video.video_url);
135+
const urlResponse = await fetch(taskResult.videos[0].video_url);
85136
return await urlResponse.blob();
86137
}
87138
}

packages/inference/src/tasks/cv/textToVideo.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { resolveProvider } from "../../lib/getInferenceProviderMapping";
33
import { getProviderHelper } from "../../lib/getProviderHelper";
44
import { makeRequestOptions } from "../../lib/makeRequestOptions";
55
import type { FalAiQueueOutput } from "../../providers/fal-ai";
6-
import type { NovitaOutput } from "../../providers/novita";
6+
import type { NovitaAsyncAPIOutput } from "../../providers/novita";
77
import type { ReplicateOutput } from "../../providers/replicate";
88
import type { BaseArgs, Options } from "../../types";
99
import { innerRequest } from "../../utils/request";
@@ -15,7 +15,7 @@ export type TextToVideoOutput = Blob;
1515
export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
1616
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
1717
const providerHelper = getProviderHelper(provider, "text-to-video");
18-
const { data: response } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(
18+
const { data: response } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaAsyncAPIOutput>(
1919
args,
2020
providerHelper,
2121
{

0 commit comments

Comments
 (0)