1717import { InferenceOutputError } from "../lib/InferenceOutputError" ;
1818import { isUrl } from "../lib/isUrl" ;
1919import type { BodyParams , UrlParams } from "../types" ;
20+ import { delay } from "../utils/delay" ;
2021import { omit } from "../utils/omit" ;
2122import {
2223 BaseConversationalTask ,
@@ -26,11 +27,11 @@ import {
2627} from "./providerHelper" ;
2728
2829const NOVITA_API_BASE_URL = "https://api.novita.ai" ;
29- export interface NovitaOutput {
30- video : {
31- video_url : string ;
32- } ;
30+
31+ export interface NovitaAsyncAPIOutput {
32+ task_id : string ;
3333}
34+
3435export class NovitaTextGenerationTask extends BaseTextGenerationTask {
3536 constructor ( ) {
3637 super ( "novita" , NOVITA_API_BASE_URL ) ;
@@ -50,38 +51,88 @@ export class NovitaConversationalTask extends BaseConversationalTask {
5051 return "/v3/openai/chat/completions" ;
5152 }
5253}
54+
5355export class NovitaTextToVideoTask extends TaskProviderHelper implements TextToVideoTaskHelper {
5456 constructor ( ) {
5557 super ( "novita" , NOVITA_API_BASE_URL ) ;
5658 }
5759
58- makeRoute ( params : UrlParams ) : string {
59- return `/v3/hf/${ params . model } ` ;
60+ override makeRoute ( params : UrlParams ) : string {
61+ if ( params . authMethod !== "provider-key" ) {
62+ return `/v3/async/${ params . model } ?_subdomain=queue` ;
63+ }
64+ return `/v3/async/${ params . model } ` ;
6065 }
6166
62- preparePayload ( params : BodyParams ) : Record < string , unknown > {
67+ override preparePayload ( params : BodyParams ) : Record < string , unknown > {
68+ const { num_inference_steps, ...restParameters } = params . args . parameters as Record < string , unknown > ;
6369 return {
6470 ...omit ( params . args , [ "inputs" , "parameters" ] ) ,
65- ...( params . args . parameters as Record < string , unknown > ) ,
71+ ...restParameters ,
72+ steps : num_inference_steps ,
6673 prompt : params . args . inputs ,
6774 } ;
6875 }
69- override async getResponse ( response : NovitaOutput ) : Promise < Blob > {
76+
77+ override async getResponse (
78+ response : NovitaAsyncAPIOutput ,
79+ url ?: string ,
80+ headers ?: Record < string , string >
81+ ) : Promise < Blob > {
82+ if ( ! url || ! headers ) {
83+ throw new InferenceOutputError ( "URL and headers are required for text-to-video task" ) ;
84+ }
85+ const taskId = response . task_id ;
86+ if ( ! taskId ) {
87+ throw new InferenceOutputError ( "No task ID found in the response" ) ;
88+ }
89+
90+ const parsedUrl = new URL ( url ) ;
91+ const baseUrl = `${ parsedUrl . protocol } //${ parsedUrl . host } ${
92+ parsedUrl . host === "router.huggingface.co" ? "/novita" : ""
93+ } `;
94+ const queryParams = parsedUrl . search ;
95+ const resultUrl = `${ baseUrl } /v3/async/task-result${ queryParams ? queryParams + '&' : '?' } task_id=${ taskId } ` ;
96+
97+ let status = '' ;
98+ let taskResult = undefined ;
99+
100+ while ( status !== 'TASK_STATUS_SUCCEED' && status !== 'TASK_STATUS_FAILED' ) {
101+ await delay ( 500 ) ;
102+ const resultResponse = await fetch ( resultUrl , { headers } ) ;
103+ if ( ! resultResponse . ok ) {
104+ throw new InferenceOutputError ( "Failed to fetch task result" ) ;
105+ }
106+ try {
107+ taskResult = await resultResponse . json ( ) ;
108+ status = taskResult . task . status ;
109+ } catch ( error ) {
110+ throw new InferenceOutputError ( "Failed to parse task result" ) ;
111+ }
112+ }
113+
114+ if ( status === 'TASK_STATUS_FAILED' ) {
115+ throw new InferenceOutputError ( "Task failed" ) ;
116+ }
117+
118+ // There will be at most one video in the response.
70119 const isValidOutput =
71- typeof response === "object" &&
72- ! ! response &&
73- "video" in response &&
74- typeof response . video === "object" &&
75- ! ! response . video &&
76- "video_url" in response . video &&
77- typeof response . video . video_url === "string" &&
78- isUrl ( response . video . video_url ) ;
120+ typeof taskResult === "object" &&
121+ ! ! taskResult &&
122+ "videos" in taskResult &&
123+ typeof taskResult . videos === "object" &&
124+ ! ! taskResult . videos &&
125+ Array . isArray ( taskResult . videos ) &&
126+ taskResult . videos . length > 0 &&
127+ "video_url" in taskResult . videos [ 0 ] &&
128+ typeof taskResult . videos [ 0 ] . video_url === "string" &&
129+ isUrl ( taskResult . videos [ 0 ] . video_url ) ;
79130
80131 if ( ! isValidOutput ) {
81- throw new InferenceOutputError ( "Expected { video: { video_url: string } }" ) ;
132+ throw new InferenceOutputError ( "Expected { videos: [ { video_url: string }] }" ) ;
82133 }
83134
84- const urlResponse = await fetch ( response . video . video_url ) ;
135+ const urlResponse = await fetch ( taskResult . videos [ 0 ] . video_url ) ;
85136 return await urlResponse . blob ( ) ;
86137 }
87138}
0 commit comments