-
Notifications
You must be signed in to change notification settings - Fork 6.5k
Add remote_decode to remote_utils
#10898
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 14 commits
Commits
Show all changes
40 commits
Select commit
Hold shift + click to select a range
47498fb
Add `remote_decode` to `remote_utils`
hlky 5fb50f1
test dependency
hlky 76f79b3
test dependency
hlky 303e920
dependency
hlky 8c405d5
dependency
hlky 19414df
dependency
hlky 2a2157e
docstrings
hlky f80ef6d
changes
hlky 2c572f7
make style
hlky 1978a8a
apply
sayakpaul e55139b
revert, add new options
hlky 4773420
Apply style fixes
github-actions[bot] 54280dd
Merge branch 'main' into remote-utils
hlky 2af1995
deprecate base64, headers not needed
hlky d80d66c
address comments
hlky 05b39ab
add license header
hlky c2a2daf
init test_remote_decode
hlky 1c4fdea
more
hlky f03a105
more test
hlky 7e7af59
more test
hlky d16c855
skeleton for xl, flux
hlky 485d99e
more test
hlky 2937eb2
flux test
hlky 86c2236
flux packed
hlky b10ea13
no scaling
hlky 7df21f2
-save
hlky 562a4c0
hunyuanvideo test
hlky 9a39e35
Apply style fixes
github-actions[bot] 217e161
Merge branch 'main' into remote-utils
hlky 3712dc3
init docs
hlky 5302645
Update src/diffusers/utils/remote_utils.py
hlky 3f69f92
comments
hlky 08ffc8f
Apply style fixes
github-actions[bot] 6c2f123
comments
hlky 82aa5cd
hybrid_inference/vae_decode
hlky 7151510
fix
hlky 9f6d91f
tip?
hlky 4c24111
tip
hlky 9c39564
api reference autodoc
hlky ca53835
install tip
hlky File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,241 @@ | ||
| import io | ||
| import json | ||
| from typing import List, Literal, Optional, Union, cast | ||
|
|
||
| import requests | ||
|
|
||
| from .deprecation_utils import deprecate | ||
| from .import_utils import is_safetensors_available, is_torch_available | ||
|
|
||
|
|
||
| if is_torch_available(): | ||
| import torch | ||
|
|
||
| from ..image_processor import VaeImageProcessor | ||
| from ..video_processor import VideoProcessor | ||
|
|
||
| if is_safetensors_available(): | ||
| import safetensors.torch | ||
|
|
||
| DTYPE_MAP = { | ||
| "float16": torch.float16, | ||
| "float32": torch.float32, | ||
| "bfloat16": torch.bfloat16, | ||
| "uint8": torch.uint8, | ||
| } | ||
|
|
||
|
|
||
| from PIL import Image | ||
|
|
||
|
|
||
def check_inputs(
    endpoint: str,
    tensor: "torch.Tensor",
    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
    do_scaling: bool = True,
    scaling_factor: Optional[float] = None,
    shift_factor: Optional[float] = None,
    output_type: Literal["mp4", "pil", "pt"] = "pil",
    return_type: Literal["mp4", "pil", "pt"] = "pil",
    image_format: Literal["png", "jpg"] = "jpg",
    partial_postprocess: bool = False,
    input_tensor_type: Literal["binary"] = "binary",
    output_tensor_type: Literal["binary"] = "binary",
    height: Optional[int] = None,
    width: Optional[int] = None,
):
    """
    Validate the arguments of `remote_decode` before any request is sent.

    The signature mirrors `remote_decode`; only `tensor`, `processor`, `do_scaling`, `scaling_factor`,
    `output_type`, `return_type`, `partial_postprocess`, `height` and `width` are inspected here.

    Raises:
        `ValueError`:
            If `tensor` is 3-dimensional (packed latents) and `height` or `width` is missing, or if
            `output_type="pt"` with `return_type="pil"` and `partial_postprocess=False` is requested without a
            `VaeImageProcessor` / `VideoProcessor`.
    """
    # Packed latents need BOTH dimensions: `remote_decode` only forwards `height`/`width` when both are set,
    # so supplying just one of them must be rejected here as well.
    if tensor.ndim == 3 and (height is None or width is None):
        raise ValueError("`height` and `width` required for packed latents.")

    # Turning an un-postprocessed tensor into PIL images is done locally by `processor`.
    needs_processor = output_type == "pt" and return_type == "pil" and not partial_postprocess
    if needs_processor and not isinstance(processor, (VaeImageProcessor, VideoProcessor)):
        raise ValueError("`processor` is required.")

    # Relying on the `do_scaling` flag alone (no explicit `scaling_factor`) is the deprecated path.
    if do_scaling and scaling_factor is None:
        deprecate(
            "do_scaling",
            "1.0.0",
            "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
            standard_warn=False,
        )
|
|
||
|
|
||
def remote_decode(
    endpoint: str,
    tensor: "torch.Tensor",
    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
    do_scaling: bool = True,
    scaling_factor: Optional[float] = None,
    shift_factor: Optional[float] = None,
    output_type: Literal["mp4", "pil", "pt"] = "pil",
    return_type: Literal["mp4", "pil", "pt"] = "pil",
    image_format: Literal["png", "jpg"] = "jpg",
    partial_postprocess: bool = False,
    input_tensor_type: Literal["binary"] = "binary",
    output_tensor_type: Literal["binary"] = "binary",
    height: Optional[int] = None,
    width: Optional[int] = None,
) -> Union[Image.Image, List[Image.Image], bytes, "torch.Tensor"]:
    """
    Send `tensor` to a remote decode endpoint and return the decoded result in the requested form.

    Args:
        endpoint (`str`):
            Endpoint for Remote Decode.
        tensor (`torch.Tensor`):
            Tensor to be decoded.
        processor (`VaeImageProcessor` or `VideoProcessor`, *optional*):
            Used with `return_type="pt"`, and `return_type="pil"` for Video models.
        do_scaling (`bool`, default `True`, *optional*):
            **DEPRECATED**. When `True` scaling e.g. `latents / self.vae.config.scaling_factor` is applied remotely.
            If `False`, input must be passed with scaling applied.
        scaling_factor (`float`, *optional*):
            Scaling is applied when passed e.g. `latents / self.vae.config.scaling_factor`.
            SD v1: 0.18215, SD XL: 0.13025, Flux: 0.3611.
        shift_factor (`float`, *optional*):
            Shift is applied when passed e.g. `latents + self.vae.config.shift_factor`. Flux: 0.1159.
        output_type (`"mp4"` or `"pil"` or `"pt"`, default `"pil"`):
            **Endpoint** output type. Subject to change. Report feedback on preferred type.

            `"mp4"`: Supported by video models. Endpoint returns `bytes` of video.
            `"pil"`: Supported by image and video models.
                Image models: Endpoint returns `bytes` of an image in `image_format`.
                Video models: Endpoint returns `torch.Tensor` with partial `postprocessing` applied.
                    Requires `processor` as a flag (any non-`None` value will work).
            `"pt"`: Supported by image and video models. Endpoint returns `torch.Tensor`.
                With `partial_postprocess=True` the tensor is a postprocessed `uint8` image tensor.

            Recommendations:
                `"pt"` with `partial_postprocess=True` is the smallest transfer for full quality.
                `"pt"` with `partial_postprocess=False` is the most compatible with third party code.
                `"pil"` with `image_format="jpg"` is the smallest transfer overall.

        return_type (`"mp4"` or `"pil"` or `"pt"`, default `"pil"`):
            **Function** return type.

            `"mp4"`: Function returns `bytes` of video.
            `"pil"`: Function returns `PIL.Image.Image`.
                With `output_type="pil"` no further processing is applied.
                With `output_type="pt"` a `PIL.Image.Image` is created.
                    `partial_postprocess=False`: `processor` is required.
                    `partial_postprocess=True`: `processor` is **not** required.
            `"pt"`: Function returns `torch.Tensor`.
                `processor` is **not** required.
                `partial_postprocess=False`: tensor is `float16` or `bfloat16`, without denormalization.
                `partial_postprocess=True`: tensor is `uint8`, denormalized.

        image_format (`"png"` or `"jpg"`, default `"jpg"`):
            Used with `output_type="pil"`. Endpoint returns `jpg` or `png`.

        partial_postprocess (`bool`, default `False`):
            Used with `output_type="pt"`.
            `partial_postprocess=False`: tensor is `float16` or `bfloat16`, without denormalization.
            `partial_postprocess=True`: tensor is `uint8`, denormalized.

        input_tensor_type (`"binary"`, default `"binary"`):
            Tensor transfer type. The legacy value `"base64"` is deprecated and replaced by `"binary"`.

        output_tensor_type (`"binary"`, default `"binary"`):
            Tensor transfer type. The legacy value `"base64"` is deprecated and replaced by `"binary"`.

        height (`int`, *optional*):
            Required for `"packed"` latents.

        width (`int`, *optional*):
            Required for `"packed"` latents.

    Returns:
        `PIL.Image.Image`, `List[PIL.Image.Image]`, `bytes` or `torch.Tensor`, depending on `output_type`,
        `return_type`, `partial_postprocess` and `processor` as described above.

    Raises:
        `ValueError`:
            If `check_inputs` rejects the arguments, or if the combination of `output_type`, `return_type` and
            `partial_postprocess` is not handled by this function.
        `RuntimeError`:
            If the endpoint responds with a non-success HTTP status. The exception carries the decoded JSON
            error body, or the raw response text when the body is not valid JSON.
    """
    if input_tensor_type == "base64":
        deprecate(
            "input_tensor_type='base64'",
            "1.0.0",
            "input_tensor_type='base64' is deprecated. Using `binary`.",
            standard_warn=False,
        )
        input_tensor_type = "binary"
    if output_tensor_type == "base64":
        deprecate(
            "output_tensor_type='base64'",
            "1.0.0",
            "output_tensor_type='base64' is deprecated. Using `binary`.",
            standard_warn=False,
        )
        output_tensor_type = "binary"
    check_inputs(
        endpoint,
        tensor,
        processor,
        do_scaling,
        scaling_factor,
        shift_factor,
        output_type,
        return_type,
        image_format,
        partial_postprocess,
        input_tensor_type,
        output_tensor_type,
        height,
        width,
    )

    # The tensor bytes travel in the request body; everything needed to rebuild and decode the tensor
    # is sent as query parameters.
    headers = {}
    parameters = {
        "image_format": image_format,
        "output_type": output_type,
        "partial_postprocess": partial_postprocess,
        "shape": list(tensor.shape),
        "dtype": str(tensor.dtype).split(".")[-1],
    }
    if do_scaling:
        if scaling_factor is not None:
            parameters["scaling_factor"] = scaling_factor
        if shift_factor is not None:
            parameters["shift_factor"] = shift_factor
        if scaling_factor is None:
            # Deprecated path: only the flag is sent, no explicit factor.
            parameters["do_scaling"] = do_scaling
    if height is not None and width is not None:
        parameters["height"] = height
        parameters["width"] = width

    # NOTE(review): `_tobytes` is an underscore-prefixed helper of `safetensors.torch`; presumably it is not
    # covered by that library's public API guarantees - confirm before upgrading safetensors.
    tensor_data = safetensors.torch._tobytes(tensor, "tensor")
    response = requests.post(endpoint, params=parameters, data=tensor_data, headers=headers)
    if not response.ok:
        try:
            error_detail = response.json()
        except ValueError:
            # Error body is not JSON (e.g. a plain-text or HTML error page): report it verbatim rather than
            # masking the HTTP failure with a decode error.
            error_detail = response.text
        raise RuntimeError(error_detail)

    # The response body is a raw tensor when a tensor was requested, or when `processor` flags a video model
    # with `output_type="pil"`. Its shape and dtype come back in the response headers.
    output_tensor = None
    if output_type == "pt" or (output_type == "pil" and processor is not None):
        response_headers = response.headers
        shape = json.loads(response_headers["shape"])
        torch_dtype = DTYPE_MAP[response_headers["dtype"]]
        output_tensor = torch.frombuffer(bytearray(response.content), dtype=torch_dtype).reshape(shape)

    if output_type == "pt":
        if partial_postprocess:
            if return_type == "pil":
                images = [Image.fromarray(image.numpy()) for image in output_tensor]
                # A single image is returned unwrapped rather than as a one-element list.
                return images[0] if len(images) == 1 else images
            if return_type == "pt":
                return output_tensor
        elif processor is None or return_type == "pt":
            return output_tensor
        elif isinstance(processor, VideoProcessor):
            return cast(
                List[Image.Image],
                processor.postprocess_video(output_tensor, output_type="pil")[0],
            )
        else:
            return cast(
                Image.Image,
                processor.postprocess(output_tensor, output_type="pil")[0],
            )
    elif output_type == "pil":
        if processor is None:
            if return_type == "pil":
                # The response body is an encoded image.
                return Image.open(io.BytesIO(response.content)).convert("RGB")
        elif return_type == "pil":
            # Scale the returned tensor by 255 and move channels last to build `uint8` frames.
            frames = (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8")
            return [Image.fromarray(frame) for frame in frames]
        elif return_type == "pt":
            return output_tensor
    elif output_type == "mp4" and return_type == "mp4":
        return response.content

    # Previously this fell through to `return output` with `output` unbound (UnboundLocalError).
    raise ValueError(
        f"Unsupported combination: output_type={output_type!r}, return_type={return_type!r}, "
        f"partial_postprocess={partial_postprocess!r}."
    )
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.