66
77from huggingface_hub import constants
88from huggingface_hub .hf_api import InferenceProviderMapping
9- from huggingface_hub .inference ._common import RequestParameters , _as_dict
9+ from huggingface_hub .inference ._common import RequestParameters , _as_dict , _as_url
1010from huggingface_hub .inference ._providers ._common import TaskProviderHelper , filter_none
1111from huggingface_hub .utils import get_session , hf_raise_for_status
1212from huggingface_hub .utils .logging import get_logger
@@ -32,6 +32,60 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str:
3232 return f"/{ mapped_model } "
3333
3434
35+ class FalAIQueueTask (TaskProviderHelper , ABC ):
36+ def __init__ (self , task : str ):
37+ super ().__init__ (provider = "fal-ai" , base_url = "https://queue.fal.run" , task = task )
38+
39+ def _prepare_headers (self , headers : Dict , api_key : str ) -> Dict :
40+ headers = super ()._prepare_headers (headers , api_key )
41+ if not api_key .startswith ("hf_" ):
42+ headers ["authorization" ] = f"Key { api_key } "
43+ return headers
44+
45+ def _prepare_route (self , mapped_model : str , api_key : str ) -> str :
46+ if api_key .startswith ("hf_" ):
47+ # Use the queue subdomain for HF routing
48+ return f"/{ mapped_model } ?_subdomain=queue"
49+ return f"/{ mapped_model } "
50+
51+ def get_response (
52+ self ,
53+ response : Union [bytes , Dict ],
54+ request_params : Optional [RequestParameters ] = None ,
55+ ) -> Any :
56+ response_dict = _as_dict (response )
57+
58+ request_id = response_dict .get ("request_id" )
59+ if not request_id :
60+ raise ValueError ("No request ID found in the response" )
61+ if request_params is None :
62+ raise ValueError (
63+ f"A `RequestParameters` object should be provided to get { self .task } responses with Fal AI."
64+ )
65+
66+ # extract the base url and query params
67+ parsed_url = urlparse (request_params .url )
68+ # a bit hacky way to concatenate the provider name without parsing `parsed_url.path`
69+ base_url = f"{ parsed_url .scheme } ://{ parsed_url .netloc } { '/fal-ai' if parsed_url .netloc == 'router.huggingface.co' else '' } "
70+ query_param = f"?{ parsed_url .query } " if parsed_url .query else ""
71+
72+ # extracting the provider model id for status and result urls
73+ # from the response as it might be different from the mapped model in `request_params.url`
74+ model_id = urlparse (response_dict .get ("response_url" )).path
75+ status_url = f"{ base_url } { str (model_id )} /status{ query_param } "
76+ result_url = f"{ base_url } { str (model_id )} { query_param } "
77+
78+ status = response_dict .get ("status" )
79+ logger .info ("Generating the output.. this can take several minutes." )
80+ while status != "COMPLETED" :
81+ time .sleep (_POLLING_INTERVAL )
82+ status_response = get_session ().get (status_url , headers = request_params .headers )
83+ hf_raise_for_status (status_response )
84+ status = status_response .json ().get ("status" )
85+
86+ return get_session ().get (result_url , headers = request_params .headers ).json ()
87+
88+
3589class FalAIAutomaticSpeechRecognitionTask (FalAITask ):
3690 def __init__ (self ):
3791 super ().__init__ ("automatic-speech-recognition" )
@@ -110,23 +164,10 @@ def get_response(self, response: Union[bytes, Dict], request_params: Optional[Re
110164 return get_session ().get (url ).content
111165
112166
113- class FalAITextToVideoTask (FalAITask ):
167+ class FalAITextToVideoTask (FalAIQueueTask ):
114168 def __init__ (self ):
115169 super ().__init__ ("text-to-video" )
116170
117- def _prepare_base_url (self , api_key : str ) -> str :
118- if api_key .startswith ("hf_" ):
119- return super ()._prepare_base_url (api_key )
120- else :
121- logger .info (f"Calling '{ self .provider } ' provider directly." )
122- return "https://queue.fal.run"
123-
124- def _prepare_route (self , mapped_model : str , api_key : str ) -> str :
125- if api_key .startswith ("hf_" ):
126- # Use the queue subdomain for HF routing
127- return f"/{ mapped_model } ?_subdomain=queue"
128- return f"/{ mapped_model } "
129-
130171 def _prepare_payload_as_dict (
131172 self , inputs : Any , parameters : Dict , provider_mapping_info : InferenceProviderMapping
132173 ) -> Optional [Dict ]:
@@ -137,36 +178,38 @@ def get_response(
137178 response : Union [bytes , Dict ],
138179 request_params : Optional [RequestParameters ] = None ,
139180 ) -> Any :
140- response_dict = _as_dict (response )
181+ output = super ().get_response (response , request_params )
182+ url = _as_dict (output )["video" ]["url" ]
183+ return get_session ().get (url ).content
141184
142- request_id = response_dict .get ("request_id" )
143- if not request_id :
144- raise ValueError ("No request ID found in the response" )
145- if request_params is None :
146- raise ValueError (
147- "A `RequestParameters` object should be provided to get text-to-video responses with Fal AI."
148- )
149185
150- # extract the base url and query params
151- parsed_url = urlparse (request_params .url )
152- # a bit hacky way to concatenate the provider name without parsing `parsed_url.path`
153- base_url = f"{ parsed_url .scheme } ://{ parsed_url .netloc } { '/fal-ai' if parsed_url .netloc == 'router.huggingface.co' else '' } "
154- query_param = f"?{ parsed_url .query } " if parsed_url .query else ""
186+ class FalAIImageToImageTask (FalAIQueueTask ):
187+ def __init__ (self ):
188+ super ().__init__ ("image-to-image" )
155189
156- # extracting the provider model id for status and result urls
157- # from the response as it might be different from the mapped model in `request_params.url`
158- model_id = urlparse (response_dict .get ("response_url" )).path
159- status_url = f"{ base_url } { str (model_id )} /status{ query_param } "
160- result_url = f"{ base_url } { str (model_id )} { query_param } "
190+ def _prepare_payload_as_dict (
191+ self , inputs : Any , parameters : Dict , provider_mapping_info : InferenceProviderMapping
192+ ) -> Optional [Dict ]:
193+ image_url = _as_url (inputs , default_mime_type = "image/jpeg" )
194+ payload : Dict [str , Any ] = {
195+ "image_url" : image_url ,
196+ ** filter_none (parameters ),
197+ }
198+ if provider_mapping_info .adapter_weights_path is not None :
199+ lora_path = constants .HUGGINGFACE_CO_URL_TEMPLATE .format (
200+ repo_id = provider_mapping_info .hf_model_id ,
201+ revision = "main" ,
202+ filename = provider_mapping_info .adapter_weights_path ,
203+ )
204+ payload ["loras" ] = [{"path" : lora_path , "scale" : 1 }]
161205
162- status = response_dict .get ("status" )
163- logger .info ("Generating the video.. this can take several minutes." )
164- while status != "COMPLETED" :
165- time .sleep (_POLLING_INTERVAL )
166- status_response = get_session ().get (status_url , headers = request_params .headers )
167- hf_raise_for_status (status_response )
168- status = status_response .json ().get ("status" )
206+ return payload
169207
170- response = get_session ().get (result_url , headers = request_params .headers ).json ()
171- url = _as_dict (response )["video" ]["url" ]
208+ def get_response (
209+ self ,
210+ response : Union [bytes , Dict ],
211+ request_params : Optional [RequestParameters ] = None ,
212+ ) -> Any :
213+ output = super ().get_response (response , request_params )
214+ url = _as_dict (output )["images" ][0 ]["url" ]
172215 return get_session ().get (url ).content
0 commit comments