
Commit 3720a50

[Inference] Support image to video task (#3289)

hanouticelina authored and Wauplin committed
1 parent bb5e4c7 · commit 3720a50

6 files changed: +269 −0 lines changed


docs/source/en/guides/inference.md
Lines changed: 1 addition & 0 deletions

@@ -219,6 +219,7 @@ For more details, refer to the [Inference Providers pricing documentation](https
 | [`~InferenceClient.image_classification`] | … |
 | [`~InferenceClient.image_segmentation`] | … |
 | [`~InferenceClient.image_to_image`] | … |
+| [`~InferenceClient.image_to_video`] | … |
 | [`~InferenceClient.image_to_text`] | … |
 | [`~InferenceClient.object_detection`] | … |
 | [`~InferenceClient.question_answering`] | … |
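
For context, a minimal, hedged usage sketch of the task the new table row documents (provider and model choice are illustrative; fal-ai is the provider wired up for this task later in the commit):

```py
from huggingface_hub import InferenceClient

# Illustrative only: provider, model ID, and file names are placeholders.
client = InferenceClient(provider="fal-ai")
video = client.image_to_video(
    "cat.jpg",
    model="Wan-AI/Wan2.2-I2V-A14B",
    prompt="turn the cat into a tiger",
)
with open("generated.mp4", "wb") as f:
    f.write(video)  # the method returns raw video bytes
```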

src/huggingface_hub/inference/_client.py
Lines changed: 80 additions & 0 deletions

@@ -81,6 +81,7 @@
     ImageSegmentationSubtask,
     ImageToImageTargetSize,
     ImageToTextOutput,
+    ImageToVideoTargetSize,
     ObjectDetectionOutputElement,
     Padding,
     QuestionAnsweringOutputElement,
@@ -1339,6 +1340,85 @@ def image_to_image(
         response = provider_helper.get_response(response, request_parameters)
         return _bytes_to_image(response)
 
+    def image_to_video(
+        self,
+        image: ContentT,
+        *,
+        model: Optional[str] = None,
+        prompt: Optional[str] = None,
+        negative_prompt: Optional[str] = None,
+        num_frames: Optional[float] = None,
+        num_inference_steps: Optional[int] = None,
+        guidance_scale: Optional[float] = None,
+        seed: Optional[int] = None,
+        target_size: Optional[ImageToVideoTargetSize] = None,
+        **kwargs,
+    ) -> bytes:
+        """
+        Generate a video from an input image.
+
+        Args:
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image to generate a video from. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
+            model (`str`, *optional*):
+                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+            prompt (`str`, *optional*):
+                The text prompt to guide the video generation.
+            negative_prompt (`str`, *optional*):
+                One prompt to guide what NOT to include in the video generation.
+            num_frames (`float`, *optional*):
+                The number of video frames to generate.
+            num_inference_steps (`int`, *optional*):
+                For diffusion models. The number of denoising steps. More denoising steps usually lead to a higher
+                quality video at the expense of slower inference.
+            guidance_scale (`float`, *optional*):
+                For diffusion models. A higher guidance scale value encourages the model to generate videos closely
+                linked to the text prompt, at the expense of lower video quality.
+            seed (`int`, *optional*):
+                The seed to use for the video generation.
+            target_size (`ImageToVideoTargetSize`, *optional*):
+                The size in pixels of the output video frames.
+
+        Returns:
+            `bytes`: The generated video.
+
+        Examples:
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> client = InferenceClient()
+        >>> video = client.image_to_video("cat.jpg", model="Wan-AI/Wan2.2-I2V-A14B", prompt="turn the cat into a tiger")
+        >>> with open("tiger.mp4", "wb") as f:
+        ...     f.write(video)
+        ```
+        """
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="image-to-video", model=model_id)
+        request_parameters = provider_helper.prepare_request(
+            inputs=image,
+            parameters={
+                "prompt": prompt,
+                "negative_prompt": negative_prompt,
+                "num_frames": num_frames,
+                "num_inference_steps": num_inference_steps,
+                "guidance_scale": guidance_scale,
+                "seed": seed,
+                "target_size": target_size,
+                **kwargs,
+            },
+            headers=self.headers,
+            model=model_id,
+            api_key=self.token,
+        )
+        response = self._inner_post(request_parameters)
+        response = provider_helper.get_response(response, request_parameters)
+        return response
+
     def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
         """
         Takes an input image and return text.
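
Since the `image` argument is a `ContentT`, the new method accepts several input forms, as the docstring above describes; a small illustrative sketch (file paths, URL, and model ID are placeholders):

```py
from pathlib import Path
from huggingface_hub import InferenceClient

client = InferenceClient()
model = "Wan-AI/Wan2.2-I2V-A14B"  # illustrative model ID, taken from the docstring example

# Local file path
video = client.image_to_video(Path("cat.jpg"), model=model, prompt="make the cat run")

# Raw bytes
video = client.image_to_video(Path("cat.jpg").read_bytes(), model=model, prompt="make the cat run")

# URL to an online image (PIL images are accepted as well, per the docstring)
video = client.image_to_video("https://example.com/cat.jpg", model=model, prompt="make the cat run")
```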

src/huggingface_hub/inference/_generated/_async_client.py
Lines changed: 81 additions & 0 deletions

@@ -66,6 +66,7 @@
     ImageSegmentationSubtask,
     ImageToImageTargetSize,
     ImageToTextOutput,
+    ImageToVideoTargetSize,
     ObjectDetectionOutputElement,
     Padding,
     QuestionAnsweringOutputElement,
@@ -1385,6 +1386,86 @@ async def image_to_image(
         response = provider_helper.get_response(response, request_parameters)
         return _bytes_to_image(response)
 
+    async def image_to_video(
+        self,
+        image: ContentT,
+        *,
+        model: Optional[str] = None,
+        prompt: Optional[str] = None,
+        negative_prompt: Optional[str] = None,
+        num_frames: Optional[float] = None,
+        num_inference_steps: Optional[int] = None,
+        guidance_scale: Optional[float] = None,
+        seed: Optional[int] = None,
+        target_size: Optional[ImageToVideoTargetSize] = None,
+        **kwargs,
+    ) -> bytes:
+        """
+        Generate a video from an input image.
+
+        Args:
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image to generate a video from. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
+            model (`str`, *optional*):
+                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+            prompt (`str`, *optional*):
+                The text prompt to guide the video generation.
+            negative_prompt (`str`, *optional*):
+                One prompt to guide what NOT to include in the video generation.
+            num_frames (`float`, *optional*):
+                The number of video frames to generate.
+            num_inference_steps (`int`, *optional*):
+                For diffusion models. The number of denoising steps. More denoising steps usually lead to a higher
+                quality video at the expense of slower inference.
+            guidance_scale (`float`, *optional*):
+                For diffusion models. A higher guidance scale value encourages the model to generate videos closely
+                linked to the text prompt, at the expense of lower video quality.
+            seed (`int`, *optional*):
+                The seed to use for the video generation.
+            target_size (`ImageToVideoTargetSize`, *optional*):
+                The size in pixels of the output video frames.
+
+        Returns:
+            `bytes`: The generated video.
+
+        Examples:
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> client = AsyncInferenceClient()
+        >>> video = await client.image_to_video("cat.jpg", model="Wan-AI/Wan2.2-I2V-A14B", prompt="turn the cat into a tiger")
+        >>> with open("tiger.mp4", "wb") as f:
+        ...     f.write(video)
+        ```
+        """
+        model_id = model or self.model
+        provider_helper = get_provider_helper(self.provider, task="image-to-video", model=model_id)
+        request_parameters = provider_helper.prepare_request(
+            inputs=image,
+            parameters={
+                "prompt": prompt,
+                "negative_prompt": negative_prompt,
+                "num_frames": num_frames,
+                "num_inference_steps": num_inference_steps,
+                "guidance_scale": guidance_scale,
+                "seed": seed,
+                "target_size": target_size,
+                **kwargs,
+            },
+            headers=self.headers,
+            model=model_id,
+            api_key=self.token,
+        )
+        response = await self._inner_post(request_parameters)
+        response = provider_helper.get_response(response, request_parameters)
+        return response
+
     async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
         """
         Takes an input image and return text.
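
As the docstring notes, the async variant must be awaited; a minimal sketch of driving it with `asyncio.run` (file names and model ID are illustrative):

```py
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    video = await client.image_to_video(
        "cat.jpg",
        model="Wan-AI/Wan2.2-I2V-A14B",
        prompt="turn the cat into a tiger",
    )
    with open("tiger.mp4", "wb") as f:
        f.write(video)


asyncio.run(main())
```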

src/huggingface_hub/inference/_providers/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -13,6 +13,7 @@
 from .fal_ai import (
     FalAIAutomaticSpeechRecognitionTask,
     FalAIImageToImageTask,
+    FalAIImageToVideoTask,
     FalAITextToImageTask,
     FalAITextToSpeechTask,
     FalAITextToVideoTask,
@@ -79,6 +80,7 @@
         "text-to-image": FalAITextToImageTask(),
         "text-to-speech": FalAITextToSpeechTask(),
         "text-to-video": FalAITextToVideoTask(),
+        "image-to-video": FalAIImageToVideoTask(),
         "image-to-image": FalAIImageToImageTask(),
     },
     "featherless-ai": {

src/huggingface_hub/inference/_providers/fal_ai.py
Lines changed: 31 additions & 0 deletions

@@ -213,3 +213,34 @@ def get_response(
         output = super().get_response(response, request_params)
         url = _as_dict(output)["images"][0]["url"]
         return get_session().get(url).content
+
+
+class FalAIImageToVideoTask(FalAIQueueTask):
+    def __init__(self):
+        super().__init__("image-to-video")
+
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
+        image_url = _as_url(inputs, default_mime_type="image/jpeg")
+        payload: Dict[str, Any] = {
+            "image_url": image_url,
+            **filter_none(parameters),
+        }
+        if provider_mapping_info.adapter_weights_path is not None:
+            lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
+                repo_id=provider_mapping_info.hf_model_id,
+                revision="main",
+                filename=provider_mapping_info.adapter_weights_path,
+            )
+            payload["loras"] = [{"path": lora_path, "scale": 1}]
+        return payload
+
+    def get_response(
+        self,
+        response: Union[bytes, Dict],
+        request_params: Optional[RequestParameters] = None,
+    ) -> Any:
+        output = super().get_response(response, request_params)
+        url = _as_dict(output)["video"]["url"]
+        return get_session().get(url).content
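
For illustration, the rough shape of the JSON payload `_prepare_payload_as_dict` builds when the provider mapping carries LoRA adapter weights. All values below are hypothetical, and the resolve-URL form is an assumption about `constants.HUGGINGFACE_CO_URL_TEMPLATE`:

```py
# Hypothetical payload sketch (not taken verbatim from the diff):
payload = {
    # _as_url() passes an http(s) URL through unchanged, or base64-encodes raw bytes
    # into a data: URI with the default image/jpeg MIME type.
    "image_url": "data:image/jpeg;base64,...",
    "prompt": "turn the cat into a tiger",
    # Added only when provider_mapping_info.adapter_weights_path is set; assumed to
    # resolve to something like
    # https://huggingface.co/<hf_model_id>/resolve/main/<adapter_weights_path>
    "loras": [{"path": "https://huggingface.co/user/my-lora/resolve/main/adapter.safetensors", "scale": 1}],
}
```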

tests/test_inference_providers.py
Lines changed: 74 additions & 0 deletions

@@ -22,6 +22,7 @@
     _POLLING_INTERVAL,
     FalAIAutomaticSpeechRecognitionTask,
     FalAIImageToImageTask,
+    FalAIImageToVideoTask,
     FalAITextToImageTask,
     FalAITextToSpeechTask,
     FalAITextToVideoTask,
@@ -476,6 +477,79 @@ def test_image_to_image_response(self, mocker):
         mock_sleep.assert_called_once_with(_POLLING_INTERVAL)
         assert response == b"image_content"
 
+    def test_image_to_video_payload(self):
+        helper = FalAIImageToVideoTask()
+        mapping_info = InferenceProviderMapping(
+            provider="fal-ai",
+            hf_model_id="Wan-AI/Wan2.2-I2V-A14B",
+            providerId="Wan-AI/Wan2.2-I2V-A14B",
+            task="image-to-video",
+            status="live",
+        )
+        payload = helper._prepare_payload_as_dict(
+            "https://example.com/image.png",
+            {"prompt": "a cat"},
+            mapping_info,
+        )
+        assert payload == {"image_url": "https://example.com/image.png", "prompt": "a cat"}
+
+        payload = helper._prepare_payload_as_dict(
+            b"dummy_image_data",
+            {"prompt": "a dog"},
+            mapping_info,
+        )
+        assert payload == {
+            "image_url": f"data:image/jpeg;base64,{base64.b64encode(b'dummy_image_data').decode()}",
+            "prompt": "a dog",
+        }
+
+    def test_image_to_video_response(self, mocker):
+        helper = FalAIImageToVideoTask()
+        mock_session = mocker.patch("huggingface_hub.inference._providers.fal_ai.get_session")
+        mock_sleep = mocker.patch("huggingface_hub.inference._providers.fal_ai.time.sleep")
+        mock_session.return_value.get.side_effect = [
+            # First call: status
+            mocker.Mock(json=lambda: {"status": "COMPLETED"}, headers={"Content-Type": "application/json"}),
+            # Second call: get result
+            mocker.Mock(json=lambda: {"video": {"url": "video_url"}}, headers={"Content-Type": "application/json"}),
+            # Third call: get video content
+            mocker.Mock(content=b"video_content"),
+        ]
+        api_key = helper._prepare_api_key("hf_token")
+        headers = helper._prepare_headers({}, api_key)
+        url = helper._prepare_url(api_key, "username/repo_name")
+
+        request_params = RequestParameters(
+            url=url,
+            headers=headers,
+            task="image-to-video",
+            model="username/repo_name",
+            data=None,
+            json=None,
+        )
+        response = helper.get_response(
+            b'{"request_id": "test_request_id", "status": "PROCESSING", "response_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id", "status_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id/status"}',
+            request_params,
+        )
+
+        # Verify the correct URLs were called
+        assert mock_session.return_value.get.call_count == 3
+        mock_session.return_value.get.assert_has_calls(
+            [
+                mocker.call(
+                    "https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id/status?_subdomain=queue",
+                    headers=request_params.headers,
+                ),
+                mocker.call(
+                    "https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id?_subdomain=queue",
+                    headers=request_params.headers,
+                ),
+                mocker.call("video_url"),
+            ]
+        )
+        mock_sleep.assert_called_once_with(_POLLING_INTERVAL)
+        assert response == b"video_content"
 
 
 class TestFeatherlessAIProvider:
     def test_prepare_route_chat_completionurl(self):
