Merged
Changes from 3 commits
2 changes: 1 addition & 1 deletion docs/source/en/guides/inference.md
@@ -218,7 +218,7 @@ For more details, refer to the [Inference Providers pricing documentation](https
| [`~InferenceClient.fill_mask`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.image_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.image_segmentation`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.image_to_text`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.object_detection`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
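With fal-ai now listed as supporting image-to-image in the docs table, the task can be called through the regular client. A minimal usage sketch, assuming a recent `huggingface_hub` release with provider routing; the model id is illustrative:

```python
# Minimal usage sketch (assumes a valid HF token; model id is illustrative).
from huggingface_hub import InferenceClient

client = InferenceClient(provider="fal-ai", api_key="hf_***")
image = client.image_to_image(
    "cat.png",                             # local path, raw bytes, or URL
    prompt="Turn the cat into a tiger",
    model="stabilityai/sdxl-refiner-1.0",  # illustrative model id
)
image.save("tiger.png")                    # image_to_image returns a PIL.Image
```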
1 change: 1 addition & 0 deletions src/huggingface_hub/inference/_client.py
@@ -1338,6 +1338,7 @@ def image_to_image(
api_key=self.token,
)
response = self._inner_post(request_parameters)
response = provider_helper.get_response(response, request_parameters)
return _bytes_to_image(response)

def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
1 change: 1 addition & 0 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -1384,6 +1384,7 @@ async def image_to_image(
api_key=self.token,
)
response = await self._inner_post(request_parameters)
response = provider_helper.get_response(response, request_parameters)
return _bytes_to_image(response)

async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
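The async client gets the identical one-line change, so the flow mirrors the sync sketch above. A hedged async counterpart, with the same illustrative values:

```python
# Async counterpart sketch (same assumptions as the sync example above).
import asyncio
from huggingface_hub import AsyncInferenceClient

async def main():
    client = AsyncInferenceClient(provider="fal-ai", api_key="hf_***")
    image = await client.image_to_image(
        "cat.png",
        prompt="Turn the cat into a tiger",
        model="stabilityai/sdxl-refiner-1.0",  # illustrative model id
    )
    image.save("tiger.png")

asyncio.run(main())
```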
2 changes: 2 additions & 0 deletions src/huggingface_hub/inference/_providers/__init__.py
@@ -12,6 +12,7 @@
from .cohere import CohereConversationalTask
from .fal_ai import (
FalAIAutomaticSpeechRecognitionTask,
FalAIImageToImageTask,
FalAITextToImageTask,
FalAITextToSpeechTask,
FalAITextToVideoTask,
@@ -78,6 +79,7 @@
"text-to-image": FalAITextToImageTask(),
"text-to-speech": FalAITextToSpeechTask(),
"text-to-video": FalAITextToVideoTask(),
"image-to-image": FalAIImageToImageTask(),
},
"featherless-ai": {
"conversational": FeatherlessConversationalTask(),
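Registering `"image-to-image"` in this mapping is what lets the client resolve a helper for the new task. A small sketch of the lookup via the module's internal accessor; treat the exact `get_provider_helper` signature as an assumption:

```python
# Sketch: resolving the newly registered helper (internal API; the exact
# signature of get_provider_helper is an assumption).
from huggingface_hub.inference._providers import get_provider_helper

helper = get_provider_helper("fal-ai", task="image-to-image", model=None)
print(type(helper).__name__)  # FalAIImageToImageTask
```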
144 changes: 102 additions & 42 deletions src/huggingface_hub/inference/_providers/fal_ai.py
@@ -1,7 +1,8 @@
import base64
import time
from abc import ABC
from typing import Any, Dict, Optional, Union
from pathlib import Path
from typing import Any, BinaryIO, Dict, Optional, Union
from urllib.parse import urlparse

from huggingface_hub import constants
@@ -32,6 +33,60 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str:
return f"/{mapped_model}"


class FalAIQueueTask(TaskProviderHelper, ABC):
Contributor comment: nice :)

def __init__(self, task: str):
super().__init__(provider="fal-ai", base_url="https://queue.fal.run", task=task)

def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
headers = super()._prepare_headers(headers, api_key)
if not api_key.startswith("hf_"):
headers["authorization"] = f"Key {api_key}"
return headers

def _prepare_route(self, mapped_model: str, api_key: str) -> str:
if api_key.startswith("hf_"):
# Use the queue subdomain for HF routing
return f"/{mapped_model}?_subdomain=queue"
return f"/{mapped_model}"

def get_response(
self,
response: Union[bytes, Dict],
request_params: Optional[RequestParameters] = None,
) -> Any:
response_dict = _as_dict(response)

request_id = response_dict.get("request_id")
if not request_id:
raise ValueError("No request ID found in the response")
if request_params is None:
raise ValueError(
f"A `RequestParameters` object should be provided to get {self.task} responses with Fal AI."
)

# extract the base url and query params
parsed_url = urlparse(request_params.url)
# a bit hacky way to concatenate the provider name without parsing `parsed_url.path`
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{'/fal-ai' if parsed_url.netloc == 'router.huggingface.co' else ''}"
query_param = f"?{parsed_url.query}" if parsed_url.query else ""

# extracting the provider model id for status and result urls
# from the response as it might be different from the mapped model in `request_params.url`
model_id = urlparse(response_dict.get("response_url")).path
status_url = f"{base_url}{str(model_id)}/status{query_param}"
result_url = f"{base_url}{str(model_id)}{query_param}"

status = response_dict.get("status")
logger.info("Generating the output.. this can take several minutes.")
while status != "COMPLETED":
time.sleep(_POLLING_INTERVAL)
status_response = get_session().get(status_url, headers=request_params.headers)
hf_raise_for_status(status_response)
status = status_response.json().get("status")

return get_session().get(result_url, headers=request_params.headers).json()


class FalAIAutomaticSpeechRecognitionTask(FalAITask):
def __init__(self):
super().__init__("automatic-speech-recognition")
@@ -110,23 +165,10 @@ def get_response(self, response: Union[bytes, Dict], request_params: Optional[Re
return get_session().get(url).content


class FalAITextToVideoTask(FalAITask):
class FalAITextToVideoTask(FalAIQueueTask):
def __init__(self):
super().__init__("text-to-video")

def _prepare_base_url(self, api_key: str) -> str:
if api_key.startswith("hf_"):
return super()._prepare_base_url(api_key)
else:
logger.info(f"Calling '{self.provider}' provider directly.")
return "https://queue.fal.run"

def _prepare_route(self, mapped_model: str, api_key: str) -> str:
if api_key.startswith("hf_"):
# Use the queue subdomain for HF routing
return f"/{mapped_model}?_subdomain=queue"
return f"/{mapped_model}"

def _prepare_payload_as_dict(
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
) -> Optional[Dict]:
@@ -137,36 +179,54 @@ def get_response(
response: Union[bytes, Dict],
request_params: Optional[RequestParameters] = None,
) -> Any:
response_dict = _as_dict(response)
output = super().get_response(response, request_params)
url = _as_dict(output)["video"]["url"]
return get_session().get(url).content

request_id = response_dict.get("request_id")
if not request_id:
raise ValueError("No request ID found in the response")
if request_params is None:
raise ValueError(
"A `RequestParameters` object should be provided to get text-to-video responses with Fal AI."
)

# extract the base url and query params
parsed_url = urlparse(request_params.url)
# a bit hacky way to concatenate the provider name without parsing `parsed_url.path`
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{'/fal-ai' if parsed_url.netloc == 'router.huggingface.co' else ''}"
query_param = f"?{parsed_url.query}" if parsed_url.query else ""
class FalAIImageToImageTask(FalAIQueueTask):
def __init__(self):
super().__init__("image-to-image")

# extracting the provider model id for status and result urls
# from the response as it might be different from the mapped model in `request_params.url`
model_id = urlparse(response_dict.get("response_url")).path
status_url = f"{base_url}{str(model_id)}/status{query_param}"
result_url = f"{base_url}{str(model_id)}{query_param}"
def _prepare_payload_as_dict(
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
) -> Optional[Dict]:
if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
image_url = inputs
else:
image_bytes: bytes
if isinstance(inputs, (str, Path)):
with open(inputs, "rb") as f:
image_bytes = f.read()
elif isinstance(inputs, bytes):
image_bytes = inputs
elif isinstance(inputs, BinaryIO):
image_bytes = inputs.read()
else:
raise TypeError(f"Unsupported input type for image: {type(inputs)}")

image_b64 = base64.b64encode(image_bytes).decode()
content_type = "image/png"
image_url = f"data:{content_type};base64,{image_b64}"
payload: Dict[str, Any] = {
"image_url": image_url,
**filter_none(parameters),
}
if provider_mapping_info.adapter_weights_path is not None:
lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
repo_id=provider_mapping_info.hf_model_id,
revision="main",
filename=provider_mapping_info.adapter_weights_path,
)
payload["loras"] = [{"path": lora_path, "scale": 1}]

status = response_dict.get("status")
logger.info("Generating the video.. this can take several minutes.")
while status != "COMPLETED":
time.sleep(_POLLING_INTERVAL)
status_response = get_session().get(status_url, headers=request_params.headers)
hf_raise_for_status(status_response)
status = status_response.json().get("status")
return payload

response = get_session().get(result_url, headers=request_params.headers).json()
url = _as_dict(response)["video"]["url"]
def get_response(
self,
response: Union[bytes, Dict],
request_params: Optional[RequestParameters] = None,
) -> Any:
output = super().get_response(response, request_params)
url = _as_dict(output)["images"][0]["url"]
return get_session().get(url).content
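The shared `FalAIQueueTask.get_response` implements fal.ai's submit-then-poll queue protocol, which both text-to-video and image-to-image now reuse. A sketch of the response shapes it works with, field names taken from the code above (real fal.ai responses may carry extra fields):

```python
# Shapes assumed by FalAIQueueTask.get_response (illustrative values).
submit_response = {
    "request_id": "abc123",
    "status": "IN_QUEUE",
    "response_url": "https://queue.fal.run/<provider-model-id>/requests/abc123",
}
# The helper polls f"{base_url}{model_id}/status" every _POLLING_INTERVAL
# seconds until the status response reads:
status_response = {"status": "COMPLETED"}
# ...then fetches the final JSON from f"{base_url}{model_id}"; for
# image-to-image the subclass extracts the first image URL:
result_response = {"images": [{"url": "https://fal.media/files/output.png"}]}
```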
68 changes: 68 additions & 0 deletions tests/test_inference_providers.py
@@ -21,6 +21,7 @@
from huggingface_hub.inference._providers.fal_ai import (
_POLLING_INTERVAL,
FalAIAutomaticSpeechRecognitionTask,
FalAIImageToImageTask,
FalAITextToImageTask,
FalAITextToSpeechTask,
FalAITextToVideoTask,
@@ -404,6 +405,73 @@ def test_text_to_video_response(self, mocker):
mock_sleep.assert_called_once_with(_POLLING_INTERVAL)
assert response == b"video_content"

def test_image_to_image_payload(self):
helper = FalAIImageToImageTask()
mapping_info = InferenceProviderMapping(
provider="fal-ai",
hf_model_id="stabilityai/sdxl-refiner-1.0",
providerId="fal-ai/sdxl-refiner",
task="image-to-image",
status="live",
)
payload = helper._prepare_payload_as_dict("https://example.com/image.png", {"prompt": "a cat"}, mapping_info)
assert payload == {"image_url": "https://example.com/image.png", "prompt": "a cat"}

payload = helper._prepare_payload_as_dict(
b"dummy_image_data", {"prompt": "replace the cat with a dog"}, mapping_info
)
assert payload == {
"image_url": f"data:image/png;base64,{base64.b64encode(b'dummy_image_data').decode()}",
"prompt": "replace the cat with a dog",
}

def test_image_to_image_response(self, mocker):
helper = FalAIImageToImageTask()
mock_session = mocker.patch("huggingface_hub.inference._providers.fal_ai.get_session")
mock_sleep = mocker.patch("huggingface_hub.inference._providers.fal_ai.time.sleep")
mock_session.return_value.get.side_effect = [
# First call: status
mocker.Mock(json=lambda: {"status": "COMPLETED"}, headers={"Content-Type": "application/json"}),
# Second call: get result
mocker.Mock(json=lambda: {"images": [{"url": "image_url"}]}, headers={"Content-Type": "application/json"}),
# Third call: get image content
mocker.Mock(content=b"image_content"),
]
api_key = helper._prepare_api_key("hf_token")
headers = helper._prepare_headers({}, api_key)
url = helper._prepare_url(api_key, "username/repo_name")

request_params = RequestParameters(
url=url,
headers=headers,
task="image-to-image",
model="username/repo_name",
data=None,
json=None,
)
response = helper.get_response(
b'{"request_id": "test_request_id", "status": "PROCESSING", "response_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id", "status_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id/status"}',
request_params,
)

# Verify the correct URLs were called
assert mock_session.return_value.get.call_count == 3
mock_session.return_value.get.assert_has_calls(
[
mocker.call(
"https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id/status?_subdomain=queue",
headers=request_params.headers,
),
mocker.call(
"https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id?_subdomain=queue",
headers=request_params.headers,
),
mocker.call("image_url"),
]
)
mock_sleep.assert_called_once_with(_POLLING_INTERVAL)
assert response == b"image_content"


class TestFeatherlessAIProvider:
def test_prepare_route_chat_completionurl(self):