Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
47498fb
Add `remote_decode` to `remote_utils`
hlky Feb 25, 2025
5fb50f1
test dependency
hlky Feb 25, 2025
76f79b3
test dependency
hlky Feb 25, 2025
303e920
dependency
hlky Feb 25, 2025
8c405d5
dependency
hlky Feb 25, 2025
19414df
dependency
hlky Feb 25, 2025
2a2157e
docstrings
hlky Feb 25, 2025
f80ef6d
changes
hlky Feb 25, 2025
2c572f7
make style
hlky Feb 25, 2025
1978a8a
apply
sayakpaul Feb 25, 2025
e55139b
revert, add new options
hlky Feb 28, 2025
4773420
Apply style fixes
github-actions[bot] Feb 28, 2025
54280dd
Merge branch 'main' into remote-utils
hlky Feb 28, 2025
2af1995
deprecate base64, headers not needed
hlky Feb 28, 2025
d80d66c
address comments
hlky Feb 28, 2025
05b39ab
add license header
hlky Feb 28, 2025
c2a2daf
init test_remote_decode
hlky Feb 28, 2025
1c4fdea
more
hlky Feb 28, 2025
f03a105
more test
hlky Feb 28, 2025
7e7af59
more test
hlky Feb 28, 2025
d16c855
skeleton for xl, flux
hlky Feb 28, 2025
485d99e
more test
hlky Mar 1, 2025
2937eb2
flux test
hlky Mar 1, 2025
86c2236
flux packed
hlky Mar 1, 2025
b10ea13
no scaling
hlky Mar 1, 2025
7df21f2
-save
hlky Mar 1, 2025
562a4c0
hunyuanvideo test
hlky Mar 1, 2025
9a39e35
Apply style fixes
github-actions[bot] Mar 1, 2025
217e161
Merge branch 'main' into remote-utils
hlky Mar 1, 2025
3712dc3
init docs
hlky Mar 1, 2025
5302645
Update src/diffusers/utils/remote_utils.py
hlky Mar 1, 2025
3f69f92
comments
hlky Mar 1, 2025
08ffc8f
Apply style fixes
github-actions[bot] Mar 1, 2025
6c2f123
comments
hlky Mar 2, 2025
82aa5cd
hybrid_inference/vae_decode
hlky Mar 2, 2025
7151510
fix
hlky Mar 2, 2025
9f6d91f
tip?
hlky Mar 2, 2025
4c24111
tip
hlky Mar 2, 2025
9c39564
api reference autodoc
hlky Mar 2, 2025
ca53835
install tip
hlky Mar 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
unscale_lora_layers,
)
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
from .remote_utils import remote_decode
from .state_dict_utils import (
convert_all_state_dict_to_peft,
convert_state_dict_to_diffusers,
Expand Down
241 changes: 241 additions & 0 deletions src/diffusers/utils/remote_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
import io
import json
from typing import List, Literal, Optional, Union, cast

import requests

from .deprecation_utils import deprecate
from .import_utils import is_safetensors_available, is_torch_available


if is_torch_available():
import torch

from ..image_processor import VaeImageProcessor
from ..video_processor import VideoProcessor

if is_safetensors_available():
import safetensors.torch

DTYPE_MAP = {
"float16": torch.float16,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
"uint8": torch.uint8,
}


from PIL import Image


def check_inputs(
endpoint: str,
tensor: "torch.Tensor",
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
do_scaling: bool = True,
scaling_factor: Optional[float] = None,
shift_factor: Optional[float] = None,
output_type: Literal["mp4", "pil", "pt"] = "pil",
return_type: Literal["mp4", "pil", "pt"] = "pil",
image_format: Literal["png", "jpg"] = "jpg",
partial_postprocess: bool = False,
input_tensor_type: Literal["binary"] = "binary",
output_tensor_type: Literal["binary"] = "binary",
height: Optional[int] = None,
width: Optional[int] = None,
):
if tensor.ndim == 3 and height is None and width is None:
raise ValueError("`height` and `width` required for packed latents.")
if (
output_type == "pt"
and return_type == "pil"
and not partial_postprocess
and not isinstance(processor, (VaeImageProcessor, VideoProcessor))
):
raise ValueError("`processor` is required.")
if do_scaling and scaling_factor is None:
deprecate(
"do_scaling",
"1.0.0",
"`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
standard_warn=False,
)


def remote_decode(
endpoint: str,
tensor: "torch.Tensor",
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
do_scaling: bool = True,
scaling_factor: Optional[float] = None,
shift_factor: Optional[float] = None,
output_type: Literal["mp4", "pil", "pt"] = "pil",
return_type: Literal["mp4", "pil", "pt"] = "pil",
image_format: Literal["png", "jpg"] = "jpg",
partial_postprocess: bool = False,
input_tensor_type: Literal["binary"] = "binary",
output_tensor_type: Literal["binary"] = "binary",
height: Optional[int] = None,
width: Optional[int] = None,
) -> Union[Image.Image, List[Image.Image], bytes, "torch.Tensor"]:
"""
Args:
endpoint (`str`):
Endpoint for Remote Decode.
tensor (`torch.Tensor`):
Tensor to be decoded.
processor (`VaeImageProcessor` or `VideoProcessor`, *optional*):
Used with `return_type="pt"`, and `return_type="pil"` for Video models.
do_scaling (`bool`, default `True`, *optional*):
**DEPRECATED**. When `True` scaling e.g. `latents / self.vae.config.scaling_factor` is applied remotely. If
`False`, input must be passed with scaling applied.
scaling_factor (`float`, *optional*):
Scaling is applied when passed e.g. `latents / self.vae.config.scaling_factor`. SD v1: 0.18215 SD XL:
0.13025 Flux: 0.3611
shift_factor (`float`, *optional*):
Shift is applied when passed e.g. `latents + self.vae.config.shift_factor`. Flux: 0.1159
output_type (`"mp4"` or `"pil"` or `"pt", default `"pil"):
**Endpoint** output type. Subject to change. Report feedback on preferred type.

`"mp4": Supported by video models. Endpoint returns `bytes` of video. `"pil"`: Supported by image and video
models.
Image models: Endpoint returns `bytes` of an image in `image_format`. Video models: Endpoint returns
`torch.Tensor` with partial `postprocessing` applied.
Requires `processor` as a flag (any `None` value will work).
`"pt"`: Support by image and video models. Endpoint returns `torch.Tensor`.
With `partial_postprocess=True` the tensor is postprocessed `uint8` image tensor.

Recommendations:
`"pt"` with `partial_postprocess=True` is the smallest transfer for full quality. `"pt"` with
`partial_postprocess=False` is the most compatible with third party code. `"pil"` with
`image_format="jpg"` is the smallest transfer overall.

return_type (`"mp4"` or `"pil"` or `"pt", default `"pil"):
**Function** return type.

`"mp4": Function returns `bytes` of video. `"pil"`: Function returns `PIL.Image.Image`.
With `output_type="pil" no further processing is applied. With `output_type="pt" a `PIL.Image.Image` is
created.
`partial_postprocess=False` `processor` is required. `partial_postprocess=True` `processor` is
**not** required.
`"pt"`: Function returns `torch.Tensor`.
`processor` is **not** required. `partial_postprocess=False` tensor is `float16` or `bfloat16`, without
denormalization. `partial_postprocess=True` tensor is `uint8`, denormalized.

image_format (`"png"` or `"jpg"`, default `jpg`):
Used with `output_type="pil"`. Endpoint returns `jpg` or `png`.

partial_postprocess (`bool`, default `False`):
Used with `output_type="pt"`. `partial_postprocess=False` tensor is `float16` or `bfloat16`, without
denormalization. `partial_postprocess=True` tensor is `uint8`, denormalized.

input_tensor_type (`"binary"`, default `"binary"`):
Tensor transfer type.

output_tensor_type (`"binary"`, default `"binary"`):
Tensor transfer type.

height (`int`, **optional**):
Required for `"packed"` latents.

width (`int`, **optional**):
Required for `"packed"` latents.
"""
if input_tensor_type == "base64":
deprecate(
"input_tensor_type='base64'",
"1.0.0",
"input_tensor_type='base64' is deprecated. Using `binary`.",
standard_warn=False,
)
input_tensor_type = "binary"
if output_tensor_type == "base64":
deprecate(
"output_tensor_type='base64'",
"1.0.0",
"output_tensor_type='base64' is deprecated. Using `binary`.",
standard_warn=False,
)
output_tensor_type = "binary"
check_inputs(
endpoint,
tensor,
processor,
do_scaling,
scaling_factor,
shift_factor,
output_type,
return_type,
image_format,
partial_postprocess,
input_tensor_type,
output_tensor_type,
height,
width,
)
headers = {}
parameters = {
"image_format": image_format,
"output_type": output_type,
"partial_postprocess": partial_postprocess,
"shape": list(tensor.shape),
"dtype": str(tensor.dtype).split(".")[-1],
}
if do_scaling and scaling_factor is not None:
parameters["scaling_factor"] = scaling_factor
if do_scaling and shift_factor is not None:
parameters["shift_factor"] = shift_factor
if do_scaling and scaling_factor is None:
parameters["do_scaling"] = do_scaling
elif do_scaling and scaling_factor is None and shift_factor is None:
parameters["do_scaling"] = do_scaling
if height is not None and width is not None:
parameters["height"] = height
parameters["width"] = width
tensor_data = safetensors.torch._tobytes(tensor, "tensor")
kwargs = {"data": tensor_data}
response = requests.post(endpoint, params=parameters, **kwargs, headers=headers)
if not response.ok:
raise RuntimeError(response.json())
if output_type == "pt" or (output_type == "pil" and processor is not None):
output_tensor = response.content
parameters = response.headers
shape = json.loads(parameters["shape"])
dtype = parameters["dtype"]
torch_dtype = DTYPE_MAP[dtype]
output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
if output_type == "pt":
if partial_postprocess:
if return_type == "pil":
output = [Image.fromarray(image.numpy()) for image in output_tensor]
if len(output) == 1:
output = output[0]
elif return_type == "pt":
output = output_tensor
else:
if processor is None or return_type == "pt":
output = output_tensor
else:
if isinstance(processor, VideoProcessor):
output = cast(
List[Image.Image],
processor.postprocess_video(output_tensor, output_type="pil")[0],
)
else:
output = cast(
Image.Image,
processor.postprocess(output_tensor, output_type="pil")[0],
)
elif output_type == "pil" and return_type == "pil" and processor is None:
output = Image.open(io.BytesIO(response.content)).convert("RGB")
elif output_type == "pil" and processor is not None:
if return_type == "pil":
output = [
Image.fromarray(image)
for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8")
]
elif return_type == "pt":
output = output_tensor
elif output_type == "mp4" and return_type == "mp4":
output = response.content
return output
Loading