Merged
2 changes: 1 addition & 1 deletion docs/source/en/guides/inference.md
@@ -218,7 +218,7 @@ For more details, refer to the [Inference Providers pricing documentation](https
| [`~InferenceClient.fill_mask`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.image_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.image_segmentation`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| [`~InferenceClient.image_to_text`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.object_detection`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
2 changes: 2 additions & 0 deletions src/huggingface_hub/inference/_providers/__init__.py
@@ -12,6 +12,7 @@
from .cohere import CohereConversationalTask
from .fal_ai import (
FalAIAutomaticSpeechRecognitionTask,
FalAIImageToImageTask,
FalAITextToImageTask,
FalAITextToSpeechTask,
FalAITextToVideoTask,
@@ -78,6 +79,7 @@
"text-to-image": FalAITextToImageTask(),
"text-to-speech": FalAITextToSpeechTask(),
"text-to-video": FalAITextToVideoTask(),
"image-to-image": FalAIImageToImageTask(),
},
"featherless-ai": {
"conversational": FeatherlessConversationalTask(),
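For context, a minimal sketch of how the new registration is resolved at request time, assuming `PROVIDERS` is the provider → task → helper mapping edited above (the lookup below is illustrative, not the library's exact dispatch code):

```python
from huggingface_hub.inference._providers import PROVIDERS

# The dict edited in this diff maps provider name -> task name -> helper instance,
# so the new task is found the same way as the existing fal-ai tasks.
helper = PROVIDERS["fal-ai"]["image-to-image"]
print(type(helper).__name__)  # FalAIImageToImageTask
```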
127 changes: 85 additions & 42 deletions src/huggingface_hub/inference/_providers/fal_ai.py
@@ -6,7 +6,7 @@

from huggingface_hub import constants
from huggingface_hub.hf_api import InferenceProviderMapping
from huggingface_hub.inference._common import RequestParameters, _as_dict
from huggingface_hub.inference._common import RequestParameters, _as_dict, _as_url
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
from huggingface_hub.utils import get_session, hf_raise_for_status
from huggingface_hub.utils.logging import get_logger
@@ -32,6 +32,60 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str:
return f"/{mapped_model}"


class FalAIQueueTask(TaskProviderHelper, ABC):
Contributor review comment: nice :)
def __init__(self, task: str):
super().__init__(provider="fal-ai", base_url="https://queue.fal.run", task=task)

def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
headers = super()._prepare_headers(headers, api_key)
if not api_key.startswith("hf_"):
headers["authorization"] = f"Key {api_key}"
return headers

def _prepare_route(self, mapped_model: str, api_key: str) -> str:
if api_key.startswith("hf_"):
# Use the queue subdomain for HF routing
return f"/{mapped_model}?_subdomain=queue"
return f"/{mapped_model}"

def get_response(
self,
response: Union[bytes, Dict],
request_params: Optional[RequestParameters] = None,
) -> Any:
response_dict = _as_dict(response)

request_id = response_dict.get("request_id")
if not request_id:
raise ValueError("No request ID found in the response")
if request_params is None:
raise ValueError(
f"A `RequestParameters` object should be provided to get {self.task} responses with Fal AI."
)

# extract the base url and query params
parsed_url = urlparse(request_params.url)
# a bit hacky way to concatenate the provider name without parsing `parsed_url.path`
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{'/fal-ai' if parsed_url.netloc == 'router.huggingface.co' else ''}"
query_param = f"?{parsed_url.query}" if parsed_url.query else ""

# extracting the provider model id for status and result urls
# from the response as it might be different from the mapped model in `request_params.url`
model_id = urlparse(response_dict.get("response_url")).path
status_url = f"{base_url}{str(model_id)}/status{query_param}"
result_url = f"{base_url}{str(model_id)}{query_param}"

status = response_dict.get("status")
logger.info("Generating the output.. this can take several minutes.")
while status != "COMPLETED":
time.sleep(_POLLING_INTERVAL)
status_response = get_session().get(status_url, headers=request_params.headers)
hf_raise_for_status(status_response)
status = status_response.json().get("status")

return get_session().get(result_url, headers=request_params.headers).json()
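The new `FalAIQueueTask` base class centralizes fal.ai's asynchronous queue protocol: submit a job, poll its status endpoint, then fetch the result. A standalone sketch of that flow, with endpoint shapes inferred from this diff and the test fixtures below (the model id, API key, and polling interval are placeholders):

```python
import time
import requests

BASE = "https://queue.fal.run"
MODEL_ID = "some-provider/some-model"  # hypothetical fal.ai model id
HEADERS = {"authorization": "Key <FAL_KEY>"}  # direct-call auth, mirroring _prepare_headers

# 1. Submitting a job returns a request_id plus a response_url to poll.
job = requests.post(f"{BASE}/{MODEL_ID}", json={"prompt": "a cat"}, headers=HEADERS).json()

# 2. Poll <response_url>/status until the job reports COMPLETED.
while requests.get(f"{job['response_url']}/status", headers=HEADERS).json().get("status") != "COMPLETED":
    time.sleep(1)

# 3. The final payload is served at response_url itself.
result = requests.get(job["response_url"], headers=HEADERS).json()
```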


class FalAIAutomaticSpeechRecognitionTask(FalAITask):
def __init__(self):
super().__init__("automatic-speech-recognition")
@@ -110,23 +164,10 @@ def get_response(self, response: Union[bytes, Dict], request_params: Optional[Re
return get_session().get(url).content


class FalAITextToVideoTask(FalAITask):
class FalAITextToVideoTask(FalAIQueueTask):
def __init__(self):
super().__init__("text-to-video")

def _prepare_base_url(self, api_key: str) -> str:
if api_key.startswith("hf_"):
return super()._prepare_base_url(api_key)
else:
logger.info(f"Calling '{self.provider}' provider directly.")
return "https://queue.fal.run"

def _prepare_route(self, mapped_model: str, api_key: str) -> str:
if api_key.startswith("hf_"):
# Use the queue subdomain for HF routing
return f"/{mapped_model}?_subdomain=queue"
return f"/{mapped_model}"

def _prepare_payload_as_dict(
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
) -> Optional[Dict]:
@@ -137,36 +178,38 @@ def get_response(
response: Union[bytes, Dict],
request_params: Optional[RequestParameters] = None,
) -> Any:
response_dict = _as_dict(response)
output = super().get_response(response, request_params)
url = _as_dict(output)["video"]["url"]
return get_session().get(url).content

request_id = response_dict.get("request_id")
if not request_id:
raise ValueError("No request ID found in the response")
if request_params is None:
raise ValueError(
"A `RequestParameters` object should be provided to get text-to-video responses with Fal AI."
)

# extract the base url and query params
parsed_url = urlparse(request_params.url)
# a bit hacky way to concatenate the provider name without parsing `parsed_url.path`
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{'/fal-ai' if parsed_url.netloc == 'router.huggingface.co' else ''}"
query_param = f"?{parsed_url.query}" if parsed_url.query else ""
class FalAIImageToImageTask(FalAIQueueTask):
def __init__(self):
super().__init__("image-to-image")

# extracting the provider model id for status and result urls
# from the response as it might be different from the mapped model in `request_params.url`
model_id = urlparse(response_dict.get("response_url")).path
status_url = f"{base_url}{str(model_id)}/status{query_param}"
result_url = f"{base_url}{str(model_id)}{query_param}"
def _prepare_payload_as_dict(
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
) -> Optional[Dict]:
image_url = _as_url(inputs, default_mime_type="image/jpeg")
payload: Dict[str, Any] = {
"image_url": image_url,
**filter_none(parameters),
}
if provider_mapping_info.adapter_weights_path is not None:
lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
repo_id=provider_mapping_info.hf_model_id,
revision="main",
filename=provider_mapping_info.adapter_weights_path,
)
payload["loras"] = [{"path": lora_path, "scale": 1}]

status = response_dict.get("status")
logger.info("Generating the video.. this can take several minutes.")
while status != "COMPLETED":
time.sleep(_POLLING_INTERVAL)
status_response = get_session().get(status_url, headers=request_params.headers)
hf_raise_for_status(status_response)
status = status_response.json().get("status")
return payload

response = get_session().get(result_url, headers=request_params.headers).json()
url = _as_dict(response)["video"]["url"]
def get_response(
self,
response: Union[bytes, Dict],
request_params: Optional[RequestParameters] = None,
) -> Any:
output = super().get_response(response, request_params)
url = _as_dict(output)["images"][0]["url"]
return get_session().get(url).content
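End to end, the new task is meant to be reachable through `InferenceClient`. A hedged usage sketch — the model id is hypothetical (borrowed from the test below), and `image_to_image` returning a `PIL.Image.Image` follows the client's existing behavior:

```python
from huggingface_hub import InferenceClient

client = InferenceClient(provider="fal-ai")

# A local path, raw bytes, or a URL all work here: non-URL inputs are converted
# to a base64 data URL by _as_url() before being sent as `image_url`.
image = client.image_to_image(
    "cat.png",
    prompt="replace the cat with a dog",
    model="stabilityai/sdxl-refiner-1.0",  # hypothetical model id from the test below
)
image.save("dog.png")
```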
68 changes: 68 additions & 0 deletions tests/test_inference_providers.py
@@ -21,6 +21,7 @@
from huggingface_hub.inference._providers.fal_ai import (
_POLLING_INTERVAL,
FalAIAutomaticSpeechRecognitionTask,
FalAIImageToImageTask,
FalAITextToImageTask,
FalAITextToSpeechTask,
FalAITextToVideoTask,
@@ -408,6 +409,73 @@ def test_text_to_video_response(self, mocker):
mock_sleep.assert_called_once_with(_POLLING_INTERVAL)
assert response == b"video_content"

def test_image_to_image_payload(self):
helper = FalAIImageToImageTask()
mapping_info = InferenceProviderMapping(
provider="fal-ai",
hf_model_id="stabilityai/sdxl-refiner-1.0",
providerId="fal-ai/sdxl-refiner",
task="image-to-image",
status="live",
)
payload = helper._prepare_payload_as_dict("https://example.com/image.png", {"prompt": "a cat"}, mapping_info)
assert payload == {"image_url": "https://example.com/image.png", "prompt": "a cat"}

payload = helper._prepare_payload_as_dict(
b"dummy_image_data", {"prompt": "replace the cat with a dog"}, mapping_info
)
assert payload == {
"image_url": f"data:image/jpeg;base64,{base64.b64encode(b'dummy_image_data').decode()}",
"prompt": "replace the cat with a dog",
}

def test_image_to_image_response(self, mocker):
helper = FalAIImageToImageTask()
mock_session = mocker.patch("huggingface_hub.inference._providers.fal_ai.get_session")
mock_sleep = mocker.patch("huggingface_hub.inference._providers.fal_ai.time.sleep")
mock_session.return_value.get.side_effect = [
# First call: status
mocker.Mock(json=lambda: {"status": "COMPLETED"}, headers={"Content-Type": "application/json"}),
# Second call: get result
mocker.Mock(json=lambda: {"images": [{"url": "image_url"}]}, headers={"Content-Type": "application/json"}),
# Third call: get image content
mocker.Mock(content=b"image_content"),
]
api_key = helper._prepare_api_key("hf_token")
headers = helper._prepare_headers({}, api_key)
url = helper._prepare_url(api_key, "username/repo_name")

request_params = RequestParameters(
url=url,
headers=headers,
task="image-to-image",
model="username/repo_name",
data=None,
json=None,
)
response = helper.get_response(
b'{"request_id": "test_request_id", "status": "PROCESSING", "response_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id", "status_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id/status"}',
request_params,
)

# Verify the correct URLs were called
assert mock_session.return_value.get.call_count == 3
mock_session.return_value.get.assert_has_calls(
[
mocker.call(
"https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id/status?_subdomain=queue",
headers=request_params.headers,
),
mocker.call(
"https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id?_subdomain=queue",
headers=request_params.headers,
),
mocker.call("image_url"),
]
)
mock_sleep.assert_called_once_with(_POLLING_INTERVAL)
assert response == b"image_content"
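The status/result URLs asserted above are derived from the original request URL plus the `response_url` echoed back by the queue. A condensed sketch of that rewrite, with values copied from the test fixture:

```python
from urllib.parse import urlparse

request_url = "https://router.huggingface.co/fal-ai/username/repo_name?_subdomain=queue"
response_url = "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id"

parsed = urlparse(request_url)
# Keep the HF router host (plus its /fal-ai prefix), but take the request path
# from response_url: the provider's model id may differ from the mapped model.
base_url = f"{parsed.scheme}://{parsed.netloc}" + ("/fal-ai" if parsed.netloc == "router.huggingface.co" else "")
query = f"?{parsed.query}" if parsed.query else ""
model_path = urlparse(response_url).path

status_url = f"{base_url}{model_path}/status{query}"
result_url = f"{base_url}{model_path}{query}"
# status_url ends with ".../requests/test_request_id/status?_subdomain=queue",
# matching the first mocked call above.
```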


class TestFeatherlessAIProvider:
def test_prepare_route_chat_completionurl(self):