diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 9f76be91339a..faedc68fc430 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -76,6 +76,14 @@ - local: advanced_inference/outpaint title: Outpainting title: Advanced inference +- sections: + - local: hybrid_inference/overview + title: Overview + - local: hybrid_inference/vae_decode + title: VAE Decode + - local: hybrid_inference/api_reference + title: API Reference + title: Hybrid Inference - sections: - local: using-diffusers/cogvideox title: CogVideoX diff --git a/docs/source/en/hybrid_inference/api_reference.md b/docs/source/en/hybrid_inference/api_reference.md new file mode 100644 index 000000000000..aa0a5e5ae58f --- /dev/null +++ b/docs/source/en/hybrid_inference/api_reference.md @@ -0,0 +1,5 @@ +# Hybrid Inference API Reference + +## Remote Decode + +[[autodoc]] utils.remote_utils.remote_decode diff --git a/docs/source/en/hybrid_inference/overview.md b/docs/source/en/hybrid_inference/overview.md new file mode 100644 index 000000000000..9bbe245901df --- /dev/null +++ b/docs/source/en/hybrid_inference/overview.md @@ -0,0 +1,54 @@ + + +# Hybrid Inference + +**Empowering local AI builders with Hybrid Inference** + + +> [!TIP] +> Hybrid Inference is an [experimental feature](https://huggingface.co/blog/remote_vae). +> Feedback can be provided [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml). + + + +## Why use Hybrid Inference? + +Hybrid Inference offers a fast and simple way to offload local generation requirements. + +- πŸš€ **Reduced Requirements:** Access powerful models without expensive hardware. +- πŸ’Ž **Without Compromise:** Achieve the highest quality without sacrificing performance. +- πŸ’° **Cost Effective:** It's free! πŸ€‘ +- 🎯 **Diverse Use Cases:** Fully compatible with Diffusers 🧨 and the wider community. +- πŸ”§ **Developer-Friendly:** Simple requests, fast responses. + +--- + +## Available Models + +* **VAE Decode πŸ–ΌοΈ:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed. +* **VAE Encode πŸ”’ (coming soon):** Efficiently encode images into latent representations for generation and training. +* **Text Encoders πŸ“ƒ (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow. + +--- + +## Integrations + +* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference. +* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference. + +## Contents + +The documentation is organized into two sections: + +* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference. +* **API Reference** Dive into task-specific settings and parameters. diff --git a/docs/source/en/hybrid_inference/vae_decode.md b/docs/source/en/hybrid_inference/vae_decode.md new file mode 100644 index 000000000000..1457090550c7 --- /dev/null +++ b/docs/source/en/hybrid_inference/vae_decode.md @@ -0,0 +1,345 @@ +# Getting Started: VAE Decode with Hybrid Inference + +VAE decode is an essential component of diffusion models - turning latent representations into images or videos. + +## Memory + +These tables demonstrate the VRAM requirements for VAE decode with SD v1 and SD XL on different GPUs. + +For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled decoding has to be used which increases time taken and impacts quality. + +
SD v1.5 + +| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) | +| --- | --- | --- | --- | --- | --- | +| NVIDIA GeForce RTX 4090 | 512x512 | 0.031 | 5.60% | 0.031 (0%) | 5.60% | +| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.148 | 20.00% | 0.301 (+103%) | 5.60% | +| NVIDIA GeForce RTX 4080 | 512x512 | 0.05 | 8.40% | 0.050 (0%) | 8.40% | +| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.224 | 30.00% | 0.356 (+59%) | 8.40% | +| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.066 | 11.30% | 0.066 (0%) | 11.30% | +| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.284 | 40.50% | 0.454 (+60%) | 11.40% | +| NVIDIA GeForce RTX 3090 | 512x512 | 0.062 | 5.20% | 0.062 (0%) | 5.20% | +| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.253 | 18.50% | 0.464 (+83%) | 5.20% | +| NVIDIA GeForce RTX 3080 | 512x512 | 0.07 | 12.80% | 0.070 (0%) | 12.80% | +| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.286 | 45.30% | 0.466 (+63%) | 12.90% | +| NVIDIA GeForce RTX 3070 | 512x512 | 0.102 | 15.90% | 0.102 (0%) | 15.90% | +| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.421 | 56.30% | 0.746 (+77%) | 16.00% | + +
+ +
SDXL + +| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) | +| --- | --- | --- | --- | --- | --- | +| NVIDIA GeForce RTX 4090 | 512x512 | 0.057 | 10.00% | 0.057 (0%) | 10.00% | +| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.256 | 35.50% | 0.257 (+0.4%) | 35.50% | +| NVIDIA GeForce RTX 4080 | 512x512 | 0.092 | 15.00% | 0.092 (0%) | 15.00% | +| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.406 | 53.30% | 0.406 (0%) | 53.30% | +| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.121 | 20.20% | 0.120 (-0.8%) | 20.20% | +| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.519 | 72.00% | 0.519 (0%) | 72.00% | +| NVIDIA GeForce RTX 3090 | 512x512 | 0.107 | 10.50% | 0.107 (0%) | 10.50% | +| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.459 | 38.00% | 0.460 (+0.2%) | 38.00% | +| NVIDIA GeForce RTX 3080 | 512x512 | 0.121 | 25.60% | 0.121 (0%) | 25.60% | +| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.524 | 93.00% | 0.524 (0%) | 93.00% | +| NVIDIA GeForce RTX 3070 | 512x512 | 0.183 | 31.80% | 0.183 (0%) | 31.80% | +| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.794 | 96.40% | 0.794 (0%) | 96.40% | + +
+ +## Available VAEs + +| | **Endpoint** | **Model** | +|:-:|:-----------:|:--------:| +| **Stable Diffusion v1** | [https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud](https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) | +| **Stable Diffusion XL** | [https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud](https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) | +| **Flux** | [https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud](https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) | +| **HunyuanVideo** | [https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud](https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud) | [`hunyuanvideo-community/HunyuanVideo`](https://hf.co/hunyuanvideo-community/HunyuanVideo) | + + +> [!TIP] +> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml). + + +## Code + +> [!TIP] +> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main` + + +A helper method simplifies interacting with Hybrid Inference. + +```python +from diffusers.utils.remote_utils import remote_decode +``` + +### Basic example + +Here, we show how to use the remote VAE on random tensors. + +
Code + +```python +image = remote_decode( + endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16), + scaling_factor=0.18215, +) +``` + +
+ +
+ +
+ +Usage for Flux is slightly different. Flux latents are packed so we need to send the `height` and `width`. + +
Code + +```python +image = remote_decode( + endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=torch.randn([1, 4096, 64], dtype=torch.float16), + height=1024, + width=1024, + scaling_factor=0.3611, + shift_factor=0.1159, +) +``` + +
+ +
+ +
+ +Finally, an example for HunyuanVideo. + +
Code + +```python +video = remote_decode( + endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=torch.randn([1, 16, 3, 40, 64], dtype=torch.float16), + output_type="mp4", +) +with open("video.mp4", "wb") as f: + f.write(video) +``` + +
+ +
+ +
+ + +### Generation + +But we want to use the VAE on an actual pipeline to get an actual image, not random noise. The example below shows how to do it with SD v1.5. + +
Code + +```python +from diffusers import StableDiffusionPipeline + +pipe = StableDiffusionPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", + torch_dtype=torch.float16, + variant="fp16", + vae=None, +).to("cuda") + +prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious" + +latent = pipe( + prompt=prompt, + output_type="latent", +).images +image = remote_decode( + endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=latent, + scaling_factor=0.18215, +) +image.save("test.jpg") +``` + +
+ +
+ +
+ +Here’s another example with Flux. + +
Code + +```python +from diffusers import FluxPipeline + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + torch_dtype=torch.bfloat16, + vae=None, +).to("cuda") + +prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious" + +latent = pipe( + prompt=prompt, + guidance_scale=0.0, + num_inference_steps=4, + output_type="latent", +).images +image = remote_decode( + endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=latent, + height=1024, + width=1024, + scaling_factor=0.3611, + shift_factor=0.1159, +) +image.save("test.jpg") +``` + +
+ +
+ +
+ +Here’s an example with HunyuanVideo. + +
Code + +```python +from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel + +model_id = "hunyuanvideo-community/HunyuanVideo" +transformer = HunyuanVideoTransformer3DModel.from_pretrained( + model_id, subfolder="transformer", torch_dtype=torch.bfloat16 +) +pipe = HunyuanVideoPipeline.from_pretrained( + model_id, transformer=transformer, vae=None, torch_dtype=torch.float16 +).to("cuda") + +latent = pipe( + prompt="A cat walks on the grass, realistic", + height=320, + width=512, + num_frames=61, + num_inference_steps=30, + output_type="latent", +).frames + +video = remote_decode( + endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=latent, + output_type="mp4", +) + +if isinstance(video, bytes): + with open("video.mp4", "wb") as f: + f.write(video) +``` + +
+ +
+ +
+ + +### Queueing + +One of the great benefits of using a remote VAE is that we can queue multiple generation requests. While the current latent is being processed for decoding, we can already queue another one. This helps improve concurrency. + + +
Code + +```python +import queue +import threading +from IPython.display import display +from diffusers import StableDiffusionPipeline + +def decode_worker(q: queue.Queue): + while True: + item = q.get() + if item is None: + break + image = remote_decode( + endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=item, + scaling_factor=0.18215, + ) + display(image) + q.task_done() + +q = queue.Queue() +thread = threading.Thread(target=decode_worker, args=(q,), daemon=True) +thread.start() + +def decode(latent: torch.Tensor): + q.put(latent) + +prompts = [ + "Blueberry ice cream, in a stylish modern glass , ice cubes, nuts, mint leaves, splashing milk cream, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious", + "Lemonade in a glass, mint leaves, in an aqua and white background, flowers, ice cubes, halo, fluid motion, dynamic movement, soft lighting, digital painting, rule of thirds composition, Art by Greg rutkowski, Coby whitmore", + "Comic book art, beautiful, vintage, pastel neon colors, extremely detailed pupils, delicate features, light on face, slight smile, Artgerm, Mary Blair, Edmund Dulac, long dark locks, bangs, glowing, fashionable style, fairytale ambience, hot pink.", + "Masterpiece, vanilla cone ice cream garnished with chocolate syrup, crushed nuts, choco flakes, in a brown background, gold, cinematic lighting, Art by WLOP", + "A bowl of milk, falling cornflakes, berries, blueberries, in a white background, soft lighting, intricate details, rule of thirds, octane render, volumetric lighting", + "Cold Coffee with cream, crushed almonds, in a glass, choco flakes, ice cubes, wet, in a wooden background, cinematic lighting, hyper realistic painting, art by Carne Griffiths, octane render, volumetric lighting, fluid motion, dynamic movement, muted colors,", +] + +pipe = StableDiffusionPipeline.from_pretrained( + "Lykon/dreamshaper-8", + torch_dtype=torch.float16, + vae=None, +).to("cuda") + +pipe.unet = pipe.unet.to(memory_format=torch.channels_last) +pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + +_ = pipe( + prompt=prompts[0], + output_type="latent", +) + +for prompt in prompts: + latent = pipe( + prompt=prompt, + output_type="latent", + ).images + decode(latent) + +q.put(None) +thread.join() +``` + +
+ + +
+ +
+ +## Integrations + +* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference. +* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference. diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 08b1713d0e31..6702ea2efbc8 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -116,6 +116,7 @@ unscale_lora_layers, ) from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil +from .remote_utils import remote_decode from .state_dict_utils import ( convert_all_state_dict_to_peft, convert_state_dict_to_diffusers, diff --git a/src/diffusers/utils/remote_utils.py b/src/diffusers/utils/remote_utils.py new file mode 100644 index 000000000000..12bcc94af74f --- /dev/null +++ b/src/diffusers/utils/remote_utils.py @@ -0,0 +1,334 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import json +from typing import List, Literal, Optional, Union, cast + +import requests + +from .deprecation_utils import deprecate +from .import_utils import is_safetensors_available, is_torch_available + + +if is_torch_available(): + import torch + + from ..image_processor import VaeImageProcessor + from ..video_processor import VideoProcessor + + if is_safetensors_available(): + import safetensors.torch + + DTYPE_MAP = { + "float16": torch.float16, + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "uint8": torch.uint8, + } + + +from PIL import Image + + +def detect_image_type(data: bytes) -> str: + if data.startswith(b"\xff\xd8"): + return "jpeg" + elif data.startswith(b"\x89PNG\r\n\x1a\n"): + return "png" + elif data.startswith(b"GIF87a") or data.startswith(b"GIF89a"): + return "gif" + elif data.startswith(b"BM"): + return "bmp" + return "unknown" + + +def check_inputs( + endpoint: str, + tensor: "torch.Tensor", + processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None, + do_scaling: bool = True, + scaling_factor: Optional[float] = None, + shift_factor: Optional[float] = None, + output_type: Literal["mp4", "pil", "pt"] = "pil", + return_type: Literal["mp4", "pil", "pt"] = "pil", + image_format: Literal["png", "jpg"] = "jpg", + partial_postprocess: bool = False, + input_tensor_type: Literal["binary"] = "binary", + output_tensor_type: Literal["binary"] = "binary", + height: Optional[int] = None, + width: Optional[int] = None, +): + if tensor.ndim == 3 and height is None and width is None: + raise ValueError("`height` and `width` required for packed latents.") + if ( + output_type == "pt" + and return_type == "pil" + and not partial_postprocess + and not isinstance(processor, (VaeImageProcessor, VideoProcessor)) + ): + raise ValueError("`processor` is required.") + if do_scaling and scaling_factor is None: + deprecate( + "do_scaling", + "1.0.0", + "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.", + standard_warn=False, + ) + + +def postprocess( + response: requests.Response, + processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None, + output_type: Literal["mp4", "pil", "pt"] = "pil", + return_type: Literal["mp4", "pil", "pt"] = "pil", + partial_postprocess: bool = False, +): + if output_type == "pt" or (output_type == "pil" and processor is not None): + output_tensor = response.content + parameters = response.headers + shape = json.loads(parameters["shape"]) + dtype = parameters["dtype"] + torch_dtype = DTYPE_MAP[dtype] + output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape) + if output_type == "pt": + if partial_postprocess: + if return_type == "pil": + output = [Image.fromarray(image.numpy()) for image in output_tensor] + if len(output) == 1: + output = output[0] + elif return_type == "pt": + output = output_tensor + else: + if processor is None or return_type == "pt": + output = output_tensor + else: + if isinstance(processor, VideoProcessor): + output = cast( + List[Image.Image], + processor.postprocess_video(output_tensor, output_type="pil")[0], + ) + else: + output = cast( + Image.Image, + processor.postprocess(output_tensor, output_type="pil")[0], + ) + elif output_type == "pil" and return_type == "pil" and processor is None: + output = Image.open(io.BytesIO(response.content)).convert("RGB") + detected_format = detect_image_type(response.content) + output.format = detected_format + elif output_type == "pil" and processor is not None: + if return_type == "pil": + output = [ + Image.fromarray(image) + for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8") + ] + elif return_type == "pt": + output = output_tensor + elif output_type == "mp4" and return_type == "mp4": + output = response.content + return output + + +def prepare( + tensor: "torch.Tensor", + processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None, + do_scaling: bool = True, + scaling_factor: Optional[float] = None, + shift_factor: Optional[float] = None, + output_type: Literal["mp4", "pil", "pt"] = "pil", + image_format: Literal["png", "jpg"] = "jpg", + partial_postprocess: bool = False, + height: Optional[int] = None, + width: Optional[int] = None, +): + headers = {} + parameters = { + "image_format": image_format, + "output_type": output_type, + "partial_postprocess": partial_postprocess, + "shape": list(tensor.shape), + "dtype": str(tensor.dtype).split(".")[-1], + } + if do_scaling and scaling_factor is not None: + parameters["scaling_factor"] = scaling_factor + if do_scaling and shift_factor is not None: + parameters["shift_factor"] = shift_factor + if do_scaling and scaling_factor is None: + parameters["do_scaling"] = do_scaling + elif do_scaling and scaling_factor is None and shift_factor is None: + parameters["do_scaling"] = do_scaling + if height is not None and width is not None: + parameters["height"] = height + parameters["width"] = width + headers["Content-Type"] = "tensor/binary" + headers["Accept"] = "tensor/binary" + if output_type == "pil" and image_format == "jpg" and processor is None: + headers["Accept"] = "image/jpeg" + elif output_type == "pil" and image_format == "png" and processor is None: + headers["Accept"] = "image/png" + elif output_type == "mp4": + headers["Accept"] = "text/plain" + tensor_data = safetensors.torch._tobytes(tensor, "tensor") + return {"data": tensor_data, "params": parameters, "headers": headers} + + +def remote_decode( + endpoint: str, + tensor: "torch.Tensor", + processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None, + do_scaling: bool = True, + scaling_factor: Optional[float] = None, + shift_factor: Optional[float] = None, + output_type: Literal["mp4", "pil", "pt"] = "pil", + return_type: Literal["mp4", "pil", "pt"] = "pil", + image_format: Literal["png", "jpg"] = "jpg", + partial_postprocess: bool = False, + input_tensor_type: Literal["binary"] = "binary", + output_tensor_type: Literal["binary"] = "binary", + height: Optional[int] = None, + width: Optional[int] = None, +) -> Union[Image.Image, List[Image.Image], bytes, "torch.Tensor"]: + """ + Hugging Face Hybrid Inference that allow running VAE decode remotely. + + Args: + endpoint (`str`): + Endpoint for Remote Decode. + tensor (`torch.Tensor`): + Tensor to be decoded. + processor (`VaeImageProcessor` or `VideoProcessor`, *optional*): + Used with `return_type="pt"`, and `return_type="pil"` for Video models. + do_scaling (`bool`, default `True`, *optional*): + **DEPRECATED**. **pass `scaling_factor`/`shift_factor` instead.** **still set + do_scaling=None/do_scaling=False for no scaling until option is removed** When `True` scaling e.g. `latents + / self.vae.config.scaling_factor` is applied remotely. If `False`, input must be passed with scaling + applied. + scaling_factor (`float`, *optional*): + Scaling is applied when passed e.g. [`latents / + self.vae.config.scaling_factor`](https://github.com/huggingface/diffusers/blob/7007febae5cff000d4df9059d9cf35133e8b2ca9/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L1083C37-L1083C77). + - SD v1: 0.18215 + - SD XL: 0.13025 + - Flux: 0.3611 + If `None`, input must be passed with scaling applied. + shift_factor (`float`, *optional*): + Shift is applied when passed e.g. `latents + self.vae.config.shift_factor`. + - Flux: 0.1159 + If `None`, input must be passed with scaling applied. + output_type (`"mp4"` or `"pil"` or `"pt", default `"pil"): + **Endpoint** output type. Subject to change. Report feedback on preferred type. + + `"mp4": Supported by video models. Endpoint returns `bytes` of video. `"pil"`: Supported by image and video + models. + Image models: Endpoint returns `bytes` of an image in `image_format`. Video models: Endpoint returns + `torch.Tensor` with partial `postprocessing` applied. + Requires `processor` as a flag (any `None` value will work). + `"pt"`: Support by image and video models. Endpoint returns `torch.Tensor`. + With `partial_postprocess=True` the tensor is postprocessed `uint8` image tensor. + + Recommendations: + `"pt"` with `partial_postprocess=True` is the smallest transfer for full quality. `"pt"` with + `partial_postprocess=False` is the most compatible with third party code. `"pil"` with + `image_format="jpg"` is the smallest transfer overall. + + return_type (`"mp4"` or `"pil"` or `"pt", default `"pil"): + **Function** return type. + + `"mp4": Function returns `bytes` of video. `"pil"`: Function returns `PIL.Image.Image`. + With `output_type="pil" no further processing is applied. With `output_type="pt" a `PIL.Image.Image` is + created. + `partial_postprocess=False` `processor` is required. `partial_postprocess=True` `processor` is + **not** required. + `"pt"`: Function returns `torch.Tensor`. + `processor` is **not** required. `partial_postprocess=False` tensor is `float16` or `bfloat16`, without + denormalization. `partial_postprocess=True` tensor is `uint8`, denormalized. + + image_format (`"png"` or `"jpg"`, default `jpg`): + Used with `output_type="pil"`. Endpoint returns `jpg` or `png`. + + partial_postprocess (`bool`, default `False`): + Used with `output_type="pt"`. `partial_postprocess=False` tensor is `float16` or `bfloat16`, without + denormalization. `partial_postprocess=True` tensor is `uint8`, denormalized. + + input_tensor_type (`"binary"`, default `"binary"`): + Tensor transfer type. + + output_tensor_type (`"binary"`, default `"binary"`): + Tensor transfer type. + + height (`int`, **optional**): + Required for `"packed"` latents. + + width (`int`, **optional**): + Required for `"packed"` latents. + + Returns: + output (`Image.Image` or `List[Image.Image]` or `bytes` or `torch.Tensor`). + """ + if input_tensor_type == "base64": + deprecate( + "input_tensor_type='base64'", + "1.0.0", + "input_tensor_type='base64' is deprecated. Using `binary`.", + standard_warn=False, + ) + input_tensor_type = "binary" + if output_tensor_type == "base64": + deprecate( + "output_tensor_type='base64'", + "1.0.0", + "output_tensor_type='base64' is deprecated. Using `binary`.", + standard_warn=False, + ) + output_tensor_type = "binary" + check_inputs( + endpoint, + tensor, + processor, + do_scaling, + scaling_factor, + shift_factor, + output_type, + return_type, + image_format, + partial_postprocess, + input_tensor_type, + output_tensor_type, + height, + width, + ) + kwargs = prepare( + tensor=tensor, + processor=processor, + do_scaling=do_scaling, + scaling_factor=scaling_factor, + shift_factor=shift_factor, + output_type=output_type, + image_format=image_format, + partial_postprocess=partial_postprocess, + height=height, + width=width, + ) + response = requests.post(endpoint, **kwargs) + if not response.ok: + raise RuntimeError(response.json()) + output = postprocess( + response=response, + processor=processor, + output_type=output_type, + return_type=return_type, + partial_postprocess=partial_postprocess, + ) + return output diff --git a/tests/remote/__init__.py b/tests/remote/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/remote/test_remote_decode.py b/tests/remote/test_remote_decode.py new file mode 100644 index 000000000000..d8e7baafb7f8 --- /dev/null +++ b/tests/remote/test_remote_decode.py @@ -0,0 +1,458 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Tuple, Union + +import numpy as np +import PIL.Image +import torch + +from diffusers.image_processor import VaeImageProcessor +from diffusers.utils.remote_utils import remote_decode +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_all_close, + torch_device, +) +from diffusers.video_processor import VideoProcessor + + +enable_full_determinism() + + +class RemoteAutoencoderKLMixin: + shape: Tuple[int, ...] = None + out_hw: Tuple[int, int] = None + endpoint: str = None + dtype: torch.dtype = None + scaling_factor: float = None + shift_factor: float = None + processor_cls: Union[VaeImageProcessor, VideoProcessor] = None + output_pil_slice: torch.Tensor = None + output_pt_slice: torch.Tensor = None + partial_postprocess_return_pt_slice: torch.Tensor = None + return_pt_slice: torch.Tensor = None + width: int = None + height: int = None + + def get_dummy_inputs(self): + inputs = { + "endpoint": self.endpoint, + "tensor": torch.randn( + self.shape, + device=torch_device, + dtype=self.dtype, + generator=torch.Generator(torch_device).manual_seed(13), + ), + "scaling_factor": self.scaling_factor, + "shift_factor": self.shift_factor, + "height": self.height, + "width": self.width, + } + return inputs + + def test_no_scaling(self): + inputs = self.get_dummy_inputs() + if inputs["scaling_factor"] is not None: + inputs["tensor"] = inputs["tensor"] / inputs["scaling_factor"] + inputs["scaling_factor"] = None + if inputs["shift_factor"] is not None: + inputs["tensor"] = inputs["tensor"] + inputs["shift_factor"] + inputs["shift_factor"] = None + processor = self.processor_cls() + output = remote_decode( + output_type="pt", + # required for now, will be removed in next update + do_scaling=False, + processor=processor, + **inputs, + ) + assert isinstance(output, PIL.Image.Image) + self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}") + self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}") + self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}") + output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten()) + # Increased tolerance for Flux Packed diff [1, 0, 1, 0, 0, 0, 0, 0, 0] + self.assertTrue( + torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1), + f"{output_slice}", + ) + + def test_output_type_pt(self): + inputs = self.get_dummy_inputs() + processor = self.processor_cls() + output = remote_decode(output_type="pt", processor=processor, **inputs) + assert isinstance(output, PIL.Image.Image) + self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}") + self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}") + self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}") + output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten()) + self.assertTrue( + torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1e-2), f"{output_slice}" + ) + + # output is visually the same, slice is flaky? + def test_output_type_pil(self): + inputs = self.get_dummy_inputs() + output = remote_decode(output_type="pil", **inputs) + self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}") + self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}") + self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}") + + def test_output_type_pil_image_format(self): + inputs = self.get_dummy_inputs() + output = remote_decode(output_type="pil", image_format="png", **inputs) + self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}") + self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}") + self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}") + self.assertEqual(output.format, "png", f"Expected image format `png`, got {output.format}") + output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten()) + self.assertTrue( + torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1e-2), f"{output_slice}" + ) + + def test_output_type_pt_partial_postprocess(self): + inputs = self.get_dummy_inputs() + output = remote_decode(output_type="pt", partial_postprocess=True, **inputs) + self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}") + self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}") + self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}") + output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten()) + self.assertTrue( + torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1e-2), f"{output_slice}" + ) + + def test_output_type_pt_return_type_pt(self): + inputs = self.get_dummy_inputs() + output = remote_decode(output_type="pt", return_type="pt", **inputs) + self.assertTrue(isinstance(output, torch.Tensor), f"Expected `torch.Tensor` output, got {type(output)}") + self.assertEqual( + output.shape[2], self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.shape[2]}" + ) + self.assertEqual( + output.shape[3], self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.shape[3]}" + ) + output_slice = output[0, 0, -3:, -3:].flatten() + self.assertTrue( + torch_all_close(output_slice, self.return_pt_slice.to(output_slice.dtype), rtol=1e-3, atol=1e-3), + f"{output_slice}", + ) + + def test_output_type_pt_partial_postprocess_return_type_pt(self): + inputs = self.get_dummy_inputs() + output = remote_decode(output_type="pt", partial_postprocess=True, return_type="pt", **inputs) + self.assertTrue(isinstance(output, torch.Tensor), f"Expected `torch.Tensor` output, got {type(output)}") + self.assertEqual( + output.shape[1], self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.shape[1]}" + ) + self.assertEqual( + output.shape[2], self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.shape[2]}" + ) + output_slice = output[0, -3:, -3:, 0].flatten().cpu() + self.assertTrue( + torch_all_close(output_slice, self.partial_postprocess_return_pt_slice.to(output_slice.dtype), rtol=1e-2), + f"{output_slice}", + ) + + def test_do_scaling_deprecation(self): + inputs = self.get_dummy_inputs() + inputs.pop("scaling_factor", None) + inputs.pop("shift_factor", None) + with self.assertWarns(FutureWarning) as warning: + _ = remote_decode(output_type="pt", partial_postprocess=True, **inputs) + self.assertEqual( + str(warning.warnings[0].message), + "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.", + str(warning.warnings[0].message), + ) + + def test_input_tensor_type_base64_deprecation(self): + inputs = self.get_dummy_inputs() + with self.assertWarns(FutureWarning) as warning: + _ = remote_decode(output_type="pt", input_tensor_type="base64", partial_postprocess=True, **inputs) + self.assertEqual( + str(warning.warnings[0].message), + "input_tensor_type='base64' is deprecated. Using `binary`.", + str(warning.warnings[0].message), + ) + + def test_output_tensor_type_base64_deprecation(self): + inputs = self.get_dummy_inputs() + with self.assertWarns(FutureWarning) as warning: + _ = remote_decode(output_type="pt", output_tensor_type="base64", partial_postprocess=True, **inputs) + self.assertEqual( + str(warning.warnings[0].message), + "output_tensor_type='base64' is deprecated. Using `binary`.", + str(warning.warnings[0].message), + ) + + +class RemoteAutoencoderKLHunyuanVideoMixin(RemoteAutoencoderKLMixin): + def test_no_scaling(self): + inputs = self.get_dummy_inputs() + if inputs["scaling_factor"] is not None: + inputs["tensor"] = inputs["tensor"] / inputs["scaling_factor"] + inputs["scaling_factor"] = None + if inputs["shift_factor"] is not None: + inputs["tensor"] = inputs["tensor"] + inputs["shift_factor"] + inputs["shift_factor"] = None + processor = self.processor_cls() + output = remote_decode( + output_type="pt", + # required for now, will be removed in next update + do_scaling=False, + processor=processor, + **inputs, + ) + self.assertTrue( + isinstance(output, list) and isinstance(output[0], PIL.Image.Image), + f"Expected `List[PIL.Image.Image]` output, got {type(output)}", + ) + self.assertEqual( + output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}" + ) + self.assertEqual( + output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output[0].width}" + ) + output_slice = torch.from_numpy(np.array(output[0])[0, -3:, -3:].flatten()) + self.assertTrue( + torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1), + f"{output_slice}", + ) + + def test_output_type_pt(self): + inputs = self.get_dummy_inputs() + processor = self.processor_cls() + output = remote_decode(output_type="pt", processor=processor, **inputs) + self.assertTrue( + isinstance(output, list) and isinstance(output[0], PIL.Image.Image), + f"Expected `List[PIL.Image.Image]` output, got {type(output)}", + ) + self.assertEqual( + output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}" + ) + self.assertEqual( + output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output[0].width}" + ) + output_slice = torch.from_numpy(np.array(output[0])[0, -3:, -3:].flatten()) + self.assertTrue( + torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1), + f"{output_slice}", + ) + + # output is visually the same, slice is flaky? + def test_output_type_pil(self): + inputs = self.get_dummy_inputs() + processor = self.processor_cls() + output = remote_decode(output_type="pil", processor=processor, **inputs) + self.assertTrue( + isinstance(output, list) and isinstance(output[0], PIL.Image.Image), + f"Expected `List[PIL.Image.Image]` output, got {type(output)}", + ) + self.assertEqual( + output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}" + ) + self.assertEqual( + output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output[0].width}" + ) + + def test_output_type_pil_image_format(self): + inputs = self.get_dummy_inputs() + processor = self.processor_cls() + output = remote_decode(output_type="pil", processor=processor, image_format="png", **inputs) + self.assertTrue( + isinstance(output, list) and isinstance(output[0], PIL.Image.Image), + f"Expected `List[PIL.Image.Image]` output, got {type(output)}", + ) + self.assertEqual( + output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}" + ) + self.assertEqual( + output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output[0].width}" + ) + output_slice = torch.from_numpy(np.array(output[0])[0, -3:, -3:].flatten()) + self.assertTrue( + torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1), + f"{output_slice}", + ) + + def test_output_type_pt_partial_postprocess(self): + inputs = self.get_dummy_inputs() + output = remote_decode(output_type="pt", partial_postprocess=True, **inputs) + self.assertTrue( + isinstance(output, list) and isinstance(output[0], PIL.Image.Image), + f"Expected `List[PIL.Image.Image]` output, got {type(output)}", + ) + self.assertEqual( + output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}" + ) + self.assertEqual( + output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output[0].width}" + ) + output_slice = torch.from_numpy(np.array(output[0])[0, -3:, -3:].flatten()) + self.assertTrue( + torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1), + f"{output_slice}", + ) + + def test_output_type_pt_return_type_pt(self): + inputs = self.get_dummy_inputs() + output = remote_decode(output_type="pt", return_type="pt", **inputs) + self.assertTrue(isinstance(output, torch.Tensor), f"Expected `torch.Tensor` output, got {type(output)}") + self.assertEqual( + output.shape[3], self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.shape[2]}" + ) + self.assertEqual( + output.shape[4], self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.shape[3]}" + ) + output_slice = output[0, 0, 0, -3:, -3:].flatten() + self.assertTrue( + torch_all_close(output_slice, self.return_pt_slice.to(output_slice.dtype), rtol=1e-3, atol=1e-3), + f"{output_slice}", + ) + + def test_output_type_mp4(self): + inputs = self.get_dummy_inputs() + output = remote_decode(output_type="mp4", return_type="mp4", **inputs) + self.assertTrue(isinstance(output, bytes), f"Expected `bytes` output, got {type(output)}") + + +class RemoteAutoencoderKLSDv1Tests( + RemoteAutoencoderKLMixin, + unittest.TestCase, +): + shape = ( + 1, + 4, + 64, + 64, + ) + out_hw = ( + 512, + 512, + ) + endpoint = "https://bz0b3zkoojf30bhx.us-east-1.aws.endpoints.huggingface.cloud/" + dtype = torch.float16 + scaling_factor = 0.18215 + shift_factor = None + processor_cls = VaeImageProcessor + output_pt_slice = torch.tensor([31, 15, 11, 55, 30, 21, 66, 42, 30], dtype=torch.uint8) + partial_postprocess_return_pt_slice = torch.tensor([100, 130, 99, 133, 106, 112, 97, 100, 121], dtype=torch.uint8) + return_pt_slice = torch.tensor([-0.2177, 0.0217, -0.2258, 0.0412, -0.1687, -0.1232, -0.2416, -0.2130, -0.0543]) + + +# class RemoteAutoencoderKLSDXLTests( +# RemoteAutoencoderKLMixin, +# unittest.TestCase, +# ): +# shape = ( +# 1, +# 4, +# 128, +# 128, +# ) +# out_hw = ( +# 1024, +# 1024, +# ) +# endpoint = "https://fagf07t3bwf0615i.us-east-1.aws.endpoints.huggingface.cloud/" +# dtype = torch.float16 +# scaling_factor = 0.13025 +# shift_factor = None +# processor_cls = VaeImageProcessor +# output_pt_slice = torch.tensor([104, 52, 23, 114, 61, 35, 108, 87, 38], dtype=torch.uint8) +# partial_postprocess_return_pt_slice = torch.tensor([77, 86, 89, 49, 60, 75, 52, 65, 78], dtype=torch.uint8) +# return_pt_slice = torch.tensor([-0.3945, -0.3289, -0.2993, -0.6177, -0.5259, -0.4119, -0.5898, -0.4863, -0.3845]) + + +# class RemoteAutoencoderKLFluxTests( +# RemoteAutoencoderKLMixin, +# unittest.TestCase, +# ): +# shape = ( +# 1, +# 16, +# 128, +# 128, +# ) +# out_hw = ( +# 1024, +# 1024, +# ) +# endpoint = "https://fnohtuwsskxgxsnn.us-east-1.aws.endpoints.huggingface.cloud/" +# dtype = torch.bfloat16 +# scaling_factor = 0.3611 +# shift_factor = 0.1159 +# processor_cls = VaeImageProcessor +# output_pt_slice = torch.tensor([110, 72, 91, 62, 35, 52, 69, 55, 69], dtype=torch.uint8) +# partial_postprocess_return_pt_slice = torch.tensor( +# [202, 203, 203, 197, 195, 193, 189, 188, 178], dtype=torch.uint8 +# ) +# return_pt_slice = torch.tensor([0.5820, 0.5962, 0.5898, 0.5439, 0.5327, 0.5112, 0.4797, 0.4773, 0.3984]) + + +# class RemoteAutoencoderKLFluxPackedTests( +# RemoteAutoencoderKLMixin, +# unittest.TestCase, +# ): +# shape = ( +# 1, +# 4096, +# 64, +# ) +# out_hw = ( +# 1024, +# 1024, +# ) +# height = 1024 +# width = 1024 +# endpoint = "https://fnohtuwsskxgxsnn.us-east-1.aws.endpoints.huggingface.cloud/" +# dtype = torch.bfloat16 +# scaling_factor = 0.3611 +# shift_factor = 0.1159 +# processor_cls = VaeImageProcessor +# # slices are different due to randn on different shape. we can pack the latent instead if we want the same +# output_pt_slice = torch.tensor([96, 116, 157, 45, 67, 104, 34, 56, 89], dtype=torch.uint8) +# partial_postprocess_return_pt_slice = torch.tensor( +# [168, 212, 202, 155, 191, 185, 150, 180, 168], dtype=torch.uint8 +# ) +# return_pt_slice = torch.tensor([0.3198, 0.6631, 0.5864, 0.2131, 0.4944, 0.4482, 0.1776, 0.4153, 0.3176]) + + +# class RemoteAutoencoderKLHunyuanVideoTests( +# RemoteAutoencoderKLHunyuanVideoMixin, +# unittest.TestCase, +# ): +# shape = ( +# 1, +# 16, +# 3, +# 40, +# 64, +# ) +# out_hw = ( +# 320, +# 512, +# ) +# endpoint = "https://lsx2injm3ts8wbvv.us-east-1.aws.endpoints.huggingface.cloud/" +# dtype = torch.float16 +# scaling_factor = 0.476986 +# processor_cls = VideoProcessor +# output_pt_slice = torch.tensor([112, 92, 85, 112, 93, 85, 112, 94, 85], dtype=torch.uint8) +# partial_postprocess_return_pt_slice = torch.tensor( +# [149, 161, 168, 136, 150, 156, 129, 143, 149], dtype=torch.uint8 +# ) +# return_pt_slice = torch.tensor([0.1656, 0.2661, 0.3157, 0.0693, 0.1755, 0.2252, 0.0127, 0.1221, 0.1708])