 from huggingface_hub import constants
 from huggingface_hub.hf_api import InferenceProviderMapping
-from huggingface_hub.inference._common import RequestParameters, _as_dict
+from huggingface_hub.inference._common import RequestParameters, _as_dict, _as_url
 from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
 from huggingface_hub.utils import get_session, hf_raise_for_status
 from huggingface_hub.utils.logging import get_logger
@@ -32,6 +32,60 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str:
         return f"/{mapped_model}"


+class FalAIQueueTask(TaskProviderHelper, ABC):
+    def __init__(self, task: str):
+        super().__init__(provider="fal-ai", base_url="https://queue.fal.run", task=task)
+
+    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
+        headers = super()._prepare_headers(headers, api_key)
+        if not api_key.startswith("hf_"):
+            headers["authorization"] = f"Key {api_key}"
+        return headers
+
+    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
+        if api_key.startswith("hf_"):
+            # Use the queue subdomain for HF routing
+            return f"/{mapped_model}?_subdomain=queue"
+        return f"/{mapped_model}"
+
+    def get_response(
+        self,
+        response: Union[bytes, Dict],
+        request_params: Optional[RequestParameters] = None,
+    ) -> Any:
+        response_dict = _as_dict(response)
+
+        request_id = response_dict.get("request_id")
+        if not request_id:
+            raise ValueError("No request ID found in the response")
+        if request_params is None:
+            raise ValueError(
+                f"A `RequestParameters` object should be provided to get {self.task} responses with Fal AI."
+            )
+
+        # extract the base url and query params
+        parsed_url = urlparse(request_params.url)
+        # a bit hacky way to concatenate the provider name without parsing `parsed_url.path`
+        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{'/fal-ai' if parsed_url.netloc == 'router.huggingface.co' else ''}"
+        query_param = f"?{parsed_url.query}" if parsed_url.query else ""
+
+        # extracting the provider model id for status and result urls
+        # from the response as it might be different from the mapped model in `request_params.url`
+        model_id = urlparse(response_dict.get("response_url")).path
+        status_url = f"{base_url}{str(model_id)}/status{query_param}"
+        result_url = f"{base_url}{str(model_id)}{query_param}"
+
+        status = response_dict.get("status")
+        logger.info("Generating the output.. this can take several minutes.")
+        while status != "COMPLETED":
+            time.sleep(_POLLING_INTERVAL)
+            status_response = get_session().get(status_url, headers=request_params.headers)
+            hf_raise_for_status(status_response)
+            status = status_response.json().get("status")
+
+        return get_session().get(result_url, headers=request_params.headers).json()
+
+
 class FalAIAutomaticSpeechRecognitionTask(FalAITask):
     def __init__(self):
         super().__init__("automatic-speech-recognition")
@@ -110,23 +164,10 @@ def get_response(self, response: Union[bytes, Dict], request_params: Optional[Re
         return get_session().get(url).content


-class FalAITextToVideoTask(FalAITask):
+class FalAITextToVideoTask(FalAIQueueTask):
     def __init__(self):
         super().__init__("text-to-video")

-    def _prepare_base_url(self, api_key: str) -> str:
-        if api_key.startswith("hf_"):
-            return super()._prepare_base_url(api_key)
-        else:
-            logger.info(f"Calling '{self.provider}' provider directly.")
-            return "https://queue.fal.run"
-
-    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
-        if api_key.startswith("hf_"):
-            # Use the queue subdomain for HF routing
-            return f"/{mapped_model}?_subdomain=queue"
-        return f"/{mapped_model}"
-
     def _prepare_payload_as_dict(
         self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
     ) -> Optional[Dict]:
@@ -137,36 +178,38 @@ def get_response(
         response: Union[bytes, Dict],
         request_params: Optional[RequestParameters] = None,
     ) -> Any:
-        response_dict = _as_dict(response)
+        output = super().get_response(response, request_params)
+        url = _as_dict(output)["video"]["url"]
+        return get_session().get(url).content

-        request_id = response_dict.get("request_id")
-        if not request_id:
-            raise ValueError("No request ID found in the response")
-        if request_params is None:
-            raise ValueError(
-                "A `RequestParameters` object should be provided to get text-to-video responses with Fal AI."
-            )

-        # extract the base url and query params
-        parsed_url = urlparse(request_params.url)
-        # a bit hacky way to concatenate the provider name without parsing `parsed_url.path`
-        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{'/fal-ai' if parsed_url.netloc == 'router.huggingface.co' else ''}"
-        query_param = f"?{parsed_url.query}" if parsed_url.query else ""
+class FalAIImageToImageTask(FalAIQueueTask):
+    def __init__(self):
+        super().__init__("image-to-image")

-        # extracting the provider model id for status and result urls
-        # from the response as it might be different from the mapped model in `request_params.url`
-        model_id = urlparse(response_dict.get("response_url")).path
-        status_url = f"{base_url}{str(model_id)}/status{query_param}"
-        result_url = f"{base_url}{str(model_id)}{query_param}"
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
+        image_url = _as_url(inputs, default_mime_type="image/jpeg")
+        payload: Dict[str, Any] = {
+            "image_url": image_url,
+            **filter_none(parameters),
+        }
+        if provider_mapping_info.adapter_weights_path is not None:
+            lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
+                repo_id=provider_mapping_info.hf_model_id,
+                revision="main",
+                filename=provider_mapping_info.adapter_weights_path,
+            )
+            payload["loras"] = [{"path": lora_path, "scale": 1}]

-        status = response_dict.get("status")
-        logger.info("Generating the video.. this can take several minutes.")
-        while status != "COMPLETED":
-            time.sleep(_POLLING_INTERVAL)
-            status_response = get_session().get(status_url, headers=request_params.headers)
-            hf_raise_for_status(status_response)
-            status = status_response.json().get("status")
+        return payload

-        response = get_session().get(result_url, headers=request_params.headers).json()
-        url = _as_dict(response)["video"]["url"]
+    def get_response(
+        self,
+        response: Union[bytes, Dict],
+        request_params: Optional[RequestParameters] = None,
+    ) -> Any:
+        output = super().get_response(response, request_params)
+        url = _as_dict(output)["images"][0]["url"]
         return get_session().get(url).content
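For reviewers who want to try this locally, here is a minimal usage sketch (not part of the diff) of how the new queue-backed tasks surface through `InferenceClient`. It assumes a `huggingface_hub` build that includes this change and accepts the `provider` and `api_key` constructor arguments; the model IDs below are placeholders, not real fal-ai provider mappings.

```python
# Usage sketch only; assumes a huggingface_hub build containing this PR.
from huggingface_hub import InferenceClient

# Works with an HF token (routed via router.huggingface.co) or a direct Fal AI key.
client = InferenceClient(provider="fal-ai", api_key="hf_***")

# image-to-image goes through FalAIImageToImageTask: the request is submitted to the
# Fal AI queue, then FalAIQueueTask.get_response polls the status URL until COMPLETED
# and downloads the first image from the result payload.
edited = client.image_to_image(
    "cat.png",                          # local path, URL, or raw bytes
    prompt="Turn the cat into a tiger",
    model="<image-to-image-model-id>",  # placeholder model ID
)
edited.save("tiger.png")

# text-to-video now reuses the same queue/polling logic via FalAITextToVideoTask.
video = client.text_to_video(
    "A young man walking on the street",
    model="<text-to-video-model-id>",   # placeholder model ID
)
with open("video.mp4", "wb") as f:
    f.write(video)
```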