11import base64
2+ import time
23from abc import ABC
34from typing import Any , Dict , Optional , Union
45
5- from huggingface_hub .inference ._common import _as_dict
6+ from huggingface_hub .inference ._common import RequestParameters , _as_dict
67from huggingface_hub .inference ._providers ._common import TaskProviderHelper , filter_none
7- from huggingface_hub .utils import get_session
8+ from huggingface_hub .utils import get_session , hf_raise_for_status
9+ from huggingface_hub .utils .logging import get_logger
10+
11+
12+ logger = get_logger (__name__ )
13+
14+ # Arbitrary polling interval
15+ _POLLING_INTERVAL = 2.0
816
917
1018class FalAITask (TaskProviderHelper , ABC ):
@@ -17,7 +25,7 @@ def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
1725 headers ["authorization" ] = f"Key { api_key } "
1826 return headers
1927
20- def _prepare_route (self , mapped_model : str ) -> str :
28+ def _prepare_route (self , mapped_model : str , api_key : str ) -> str :
2129 return f"/{ mapped_model } "
2230
2331
@@ -41,7 +49,7 @@ def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model:
4149
4250 return {"audio_url" : audio_url , ** filter_none (parameters )}
4351
44- def get_response (self , response : Union [bytes , Dict ]) -> Any :
52+ def get_response (self , response : Union [bytes , Dict ], request_params : Optional [ RequestParameters ] = None ) -> Any :
4553 text = _as_dict (response )["text" ]
4654 if not isinstance (text , str ):
4755 raise ValueError (f"Unexpected output format from FalAI API. Expected string, got { type (text )} ." )
@@ -61,7 +69,7 @@ def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model:
6169 }
6270 return {"prompt" : inputs , ** parameters }
6371
64- def get_response (self , response : Union [bytes , Dict ]) -> Any :
72+ def get_response (self , response : Union [bytes , Dict ], request_params : Optional [ RequestParameters ] = None ) -> Any :
6573 url = _as_dict (response )["images" ][0 ]["url" ]
6674 return get_session ().get (url ).content
6775
@@ -73,7 +81,7 @@ def __init__(self):
7381 def _prepare_payload_as_dict (self , inputs : Any , parameters : Dict , mapped_model : str ) -> Optional [Dict ]:
7482 return {"lyrics" : inputs , ** filter_none (parameters )}
7583
76- def get_response (self , response : Union [bytes , Dict ]) -> Any :
84+ def get_response (self , response : Union [bytes , Dict ], request_params : Optional [ RequestParameters ] = None ) -> Any :
7785 url = _as_dict (response )["audio" ]["url" ]
7886 return get_session ().get (url ).content
7987
@@ -82,9 +90,52 @@ class FalAITextToVideoTask(FalAITask):
8290 def __init__ (self ):
8391 super ().__init__ ("text-to-video" )
8492
93+ def _prepare_base_url (self , api_key : str ) -> str :
94+ if api_key .startswith ("hf_" ):
95+ return super ()._prepare_base_url (api_key )
96+ else :
97+ logger .info (f"Calling '{ self .provider } ' provider directly." )
98+ return "https://queue.fal.run"
99+
100+ def _prepare_route (self , mapped_model : str , api_key : str ) -> str :
101+ if api_key .startswith ("hf_" ):
102+ # Use the queue subdomain for HF routing
103+ return f"/{ mapped_model } ?_subdomain=queue"
104+ return f"/{ mapped_model } "
105+
85106 def _prepare_payload_as_dict (self , inputs : Any , parameters : Dict , mapped_model : str ) -> Optional [Dict ]:
86107 return {"prompt" : inputs , ** filter_none (parameters )}
87108
88- def get_response (self , response : Union [bytes , Dict ]) -> Any :
109+ def get_response (
110+ self ,
111+ response : Union [bytes , Dict ],
112+ request_params : Optional [RequestParameters ] = None ,
113+ ) -> Any :
114+ response_dict = _as_dict (response )
115+
116+ request_id = response_dict .get ("request_id" )
117+ if not request_id :
118+ raise ValueError ("No request ID found in the response" )
119+ if request_params is None :
120+ raise ValueError (
121+ "A `RequestParameters` object should be provided to get text-to-video responses with Fal AI."
122+ )
123+
124+ # extract the base url and query params
125+ base_url = request_params .url .split ("?" )[0 ] # or parsed.scheme + "://" + parsed.netloc + parsed.path ?
126+ query = "?_subdomain=queue" if request_params .url .endswith ("_subdomain=queue" ) else ""
127+
128+ status_url = f"{ base_url } /requests/{ request_id } /status{ query } "
129+ result_url = f"{ base_url } /requests/{ request_id } { query } "
130+
131+ status = response_dict .get ("status" )
132+ logger .info ("Generating the video.. this can take several minutes." )
133+ while status != "COMPLETED" :
134+ time .sleep (_POLLING_INTERVAL )
135+ status_response = get_session ().get (status_url , headers = request_params .headers )
136+ hf_raise_for_status (status_response )
137+ status = status_response .json ().get ("status" )
138+
139+ response = get_session ().get (result_url , headers = request_params .headers ).json ()
89140 url = _as_dict (response )["video" ]["url" ]
90141 return get_session ().get (url ).content
0 commit comments