Commit 689c8d1

Update Remote VAE blog (#2714)
* Update Remote VAE blog
* Update remote_vae.md
1 parent aabc5e3 commit 689c8d1

File tree

1 file changed: +11 −193 lines changed


remote_vae.md

Lines changed: 11 additions & 193 deletions
@@ -16,6 +16,8 @@ Therefore, we want to pilot an idea with the community — delegating the decodi
 
 No data is stored or tracked, and code is open source. We made some changes to [huggingface-inference-toolkit](https://github.com/hlky/huggingface-inference-toolkit/tree/fix-text-support-binary) and use [custom handlers](https://huggingface.co/hlky/sd-vae-ft-mse/blob/main/handler.py).
 
+This experimental feature is developed by [Diffusers 🧨](https://huggingface.co/docs/diffusers/hybrid_inference/overview)
+
 **Table of contents**:
 
 - [Getting started](#getting-started)
@@ -37,141 +39,14 @@ Below, we cover three use cases where we think this remote VAE inference would b
 First, we have created a helper method for interacting with Remote VAEs.
 
 > [!NOTE]
-> We recommend installing `diffusers` from `main` to run the code.
+> Install `diffusers` from `main` to run the code.
 > `pip install git+https://github.com/huggingface/diffusers@main`
 
 <details><summary>Code</summary>
 <p>
 
 ```python
-from typing import cast, List, Literal, Optional, Union
-
-import base64
-import io
-import json
-import requests
-import torch
-from PIL import Image
-
-from diffusers.image_processor import VaeImageProcessor
-from diffusers.video_processor import VideoProcessor
-from safetensors.torch import _tobytes
-
-DTYPE_MAP = {
-    "float16": torch.float16,
-    "float32": torch.float32,
-    "bfloat16": torch.bfloat16,
-    "uint8": torch.uint8,
-}
-
-
-def remote_decode(
-    endpoint: str,
-    tensor: torch.Tensor,
-    processor: Optional[Union[VaeImageProcessor, VideoProcessor]] = None,
-    do_scaling: bool = True,
-    output_type: Literal["mp4", "pil", "pt"] = "pil",
-    image_format: Literal["png", "jpg"] = "jpg",
-    partial_postprocess: bool = False,
-    input_tensor_type: Literal["base64", "binary"] = "base64",
-    output_tensor_type: Literal["base64", "binary"] = "base64",
-    height: Optional[int] = None,
-    width: Optional[int] = None,
-) -> Union[Image.Image, List[Image.Image], bytes, torch.Tensor]:
-    if tensor.ndim == 3 and height is None and width is None:
-        raise ValueError("`height` and `width` required for packed latents.")
-    if output_type == "pt" and partial_postprocess is False and processor is None:
-        raise ValueError(
-            "`processor` is required with `output_type='pt'` and `partial_postprocess=False`."
-        )
-    headers = {}
-    parameters = {
-        "do_scaling": do_scaling,
-        "output_type": output_type,
-        "partial_postprocess": partial_postprocess,
-        "shape": list(tensor.shape),
-        "dtype": str(tensor.dtype).split(".")[-1],
-    }
-    if height is not None and width is not None:
-        parameters["height"] = height
-        parameters["width"] = width
-    tensor_data = _tobytes(tensor, "tensor")
-    if input_tensor_type == "base64":
-        headers["Content-Type"] = "tensor/base64"
-    elif input_tensor_type == "binary":
-        headers["Content-Type"] = "tensor/binary"
-    if output_type == "pil" and image_format == "jpg" and processor is None:
-        headers["Accept"] = "image/jpeg"
-    elif output_type == "pil" and image_format == "png" and processor is None:
-        headers["Accept"] = "image/png"
-    elif (output_tensor_type == "base64" and output_type == "pt") or (
-        output_tensor_type == "base64"
-        and output_type == "pil"
-        and processor is not None
-    ):
-        headers["Accept"] = "tensor/base64"
-    elif (output_tensor_type == "binary" and output_type == "pt") or (
-        output_tensor_type == "binary"
-        and output_type == "pil"
-        and processor is not None
-    ):
-        headers["Accept"] = "tensor/binary"
-    elif output_type == "mp4":
-        headers["Accept"] = "text/plain"
-    if input_tensor_type == "base64":
-        kwargs = {"json": {"inputs": base64.b64encode(tensor_data).decode("utf-8")}}
-    elif input_tensor_type == "binary":
-        kwargs = {"data": tensor_data}
-    response = requests.post(endpoint, params=parameters, **kwargs, headers=headers)
-    if not response.ok:
-        raise RuntimeError(response.json())
-    if output_type == "pt" or (output_type == "pil" and processor is not None):
-        if output_tensor_type == "base64":
-            content = response.json()
-            output_tensor = base64.b64decode(content["inputs"])
-            parameters = content["parameters"]
-            shape = parameters["shape"]
-            dtype = parameters["dtype"]
-        elif output_tensor_type == "binary":
-            output_tensor = response.content
-            parameters = response.headers
-            shape = json.loads(parameters["shape"])
-            dtype = parameters["dtype"]
-        torch_dtype = DTYPE_MAP[dtype]
-        output_tensor = torch.frombuffer(
-            bytearray(output_tensor), dtype=torch_dtype
-        ).reshape(shape)
-    if output_type == "pt":
-        if partial_postprocess:
-            output = [Image.fromarray(image.numpy()) for image in output_tensor]
-            if len(output) == 1:
-                output = output[0]
-        else:
-            if processor is None:
-                output = output_tensor
-            else:
-                if isinstance(processor, VideoProcessor):
-                    output = cast(
-                        List[Image.Image],
-                        processor.postprocess_video(output_tensor, output_type="pil")[0],
-                    )
-                else:
-                    output = cast(
-                        Image.Image,
-                        processor.postprocess(output_tensor, output_type="pil")[0],
-                    )
-    elif output_type == "pil" and processor is None:
-        output = Image.open(io.BytesIO(response.content)).convert("RGB")
-    elif output_type == "pil" and processor is not None:
-        output = [
-            Image.fromarray(image)
-            for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255)
-            .round()
-            .astype("uint8")
-        ]
-    elif output_type == "mp4":
-        output = response.content
-    return output
+from diffusers.utils.remote_utils import remote_decode
 ```
 
 </p>
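
For context on the hunk above: the long inline helper is removed because `remote_decode` now ships with `diffusers`. Below is a minimal sketch of how the slimmed-down snippet is meant to be used; the endpoint URL and `scaling_factor` value are taken from the snippets updated later in this commit, and the rest is illustrative rather than part of the diff.

```python
# Minimal sketch (not part of the diff): decode a Stable Diffusion-style latent
# through the packaged helper. Endpoint and scaling_factor mirror the updated
# snippets in this commit; the result is a PIL image by default.
import torch
from diffusers.utils.remote_utils import remote_decode

image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
    scaling_factor=0.18215,
)
image.save("decoded.jpg")
```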
@@ -188,6 +63,7 @@ Here, we show how to use the remote VAE on random tensors.
 image = remote_decode(
     endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
     tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
+    scaling_factor=0.18215,
 )
 ```
 
@@ -209,6 +85,8 @@ image = remote_decode(
     tensor=torch.randn([1, 4096, 64], dtype=torch.float16),
     height=1024,
     width=1024,
+    scaling_factor=0.3611,
+    shift_factor=0.1159,
 )
 ```
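
A side note on the `[1, 4096, 64]` tensor in the hunk above: Flux latents are packed, which is why `height` and `width` have to be passed (the removed Options section further down makes the same point). One way to read the shape, assuming Flux's usual 8× VAE downsampling, 16 latent channels, and 2×2 patch packing:

```python
# Hypothetical back-of-the-envelope check (not from the blog): why a 1024x1024
# Flux image corresponds to a [1, 4096, 64] packed latent, assuming 8x VAE
# downsampling, 16 latent channels, and 2x2 patch packing.
height, width = 1024, 1024
latent_h, latent_w = height // 8, width // 8   # 128 x 128 latent grid
tokens = (latent_h // 2) * (latent_w // 2)     # 64 * 64 = 4096 packed tokens
channels_per_token = 16 * 2 * 2                # 64 values per token
print(tokens, channels_per_token)              # 4096 64
```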

@@ -246,70 +124,6 @@ with open("video.mp4", "wb") as f:
 </video>
 </figure>
 
-### Options
-
-Let's review the available options.
-
-```python
-def remote_decode(
-    endpoint: str,
-    tensor: torch.Tensor,
-    processor: Optional[Union[VaeImageProcessor, VideoProcessor]] = None,
-    do_scaling: bool = True,
-    output_type: Literal["mp4", "pil", "pt"] = "pil",
-    image_format: Literal["png", "jpg"] = "jpg",
-    partial_postprocess: bool = False,
-    input_tensor_type: Literal["base64", "binary"] = "base64",
-    output_tensor_type: Literal["base64", "binary"] = "base64",
-    height: Optional[int] = None,
-    width: Optional[int] = None,
-) -> Union[Image.Image, List[Image.Image], bytes, torch.Tensor]:
-```
-
-#### Overview of decoding
-
-There are 3 parts of decoding in a pipeline: `scaling` -> `decode` -> `postprocess`.
-
-Options allow Remote VAE to be compatible with these different stages.
-
-#### `processor`
-
-With `output_type="pt"` the endpoint returns a `torch.Tensor` before `postprocess`. The final postprocessing and image creation is done locally.
-
-With `output_type="pil"` on video models `processor=VideoProcessor()` is required for some local postprocessing.
-
-#### `do_scaling`
-
-- `do_scaling=False` allows Remote VAE to work as a drop-in replacement for `pipe.vae.decode`. Scaling should be applied to input before `remote_decode`.
-- `do_scaling=True` scaling is applied by Remote VAE.
-
-#### `output_type`
-
-Image models support: `pil`, `pt`.
-
-Video models support: `mp4`, `pil`, `pt`.
-
-- `output_type="pil"` returns an image according to `image_format` for Image models and a tensor for Video models (equivalent to `postprocess_video(frames, output_type="pt")`) which has final postprocessing applied to create the frame images.
-- `output_type="pt"` with `partial_postprocess=False` returns a `torch.Tensor` before `postprocess`. The final postprocessing and image creation is done locally.
-- `output_type="pt"` with `partial_postprocess=True` returns a `torch.Tensor` with `postprocess` applied. The final image creation (`PIL.Image.fromarray`) is done locally. This reduces transfer compared to `partial_postprocess=False`.
-- `output_type="mp4"` applies `postprocess_video(frames, output_type="pil")` then `export_to_video` and returns `bytes` of the `mp4`.
-
-#### `input_tensor_type`/`output_tensor_type`
-
-Choices `base64`, `binary`.
-
-Using `binary` reduces transfer.
-
-#### `image_format`
-
-Choices `jpg`, `png`.
-
-`jpg` is faster but lower quality.
-
-#### `height`/`width`
-
-Required for packed latents in Flux. Not required with `do_scaling=False` as `unpack` occurs before scaling.
-
 
 ### Generation
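
Even though the Options section is deleted from the post, the combinations it documented still help in reading the remaining snippets. Here is a sketch of two of them, assuming the packaged helper keeps the same parameters as the removed inline version of `remote_decode` (not verified against the current `diffusers` API):

```python
# Sketch only: option combinations from the removed "Options" section above,
# assuming the packaged helper accepts the same parameters as the removed
# inline version of remote_decode.
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.remote_utils import remote_decode

endpoint = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
latent = torch.randn([1, 4, 64, 64], dtype=torch.float16)

# output_type="pt" with partial_postprocess=True: the endpoint postprocesses,
# only PIL.Image.fromarray happens locally, so less data is transferred.
image = remote_decode(
    endpoint=endpoint,
    tensor=latent,
    output_type="pt",
    partial_postprocess=True,
)

# output_type="pt" with partial_postprocess=False: the endpoint returns the
# decoded tensor before postprocess, so a local processor finishes the job;
# binary tensor transport reduces transfer compared to base64.
image = remote_decode(
    endpoint=endpoint,
    tensor=latent,
    output_type="pt",
    partial_postprocess=False,
    processor=VaeImageProcessor(),
    output_tensor_type="binary",
)
```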

@@ -337,6 +151,7 @@ latent = pipe(
 image = remote_decode(
     endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
     tensor=latent,
+    scaling_factor=0.18215,
 )
 image.save("test.jpg")
 ```
@@ -375,6 +190,8 @@ image = remote_decode(
     tensor=latent,
     height=1024,
     width=1024,
+    scaling_factor=0.3611,
+    shift_factor=0.1159,
 )
 image.save("test.jpg")
 ```
@@ -456,6 +273,7 @@ def decode_worker(q: queue.Queue):
         image = remote_decode(
             endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
             tensor=item,
+            scaling_factor=0.18215,
         )
         display(image)
         q.task_done()
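
The hunk above comes from the blog's queued-decoding example. For readers skimming the diff, here is a self-contained sketch of that pattern; the queue and thread scaffolding is assumed rather than taken from the commit, and only the `remote_decode` call (with the newly added `scaling_factor`) mirrors the diff.

```python
# Sketch of the queued-decoding pattern the hunk above belongs to. The queue
# and thread scaffolding is assumed; the remote_decode call mirrors the diff.
import queue
import threading

import torch
from diffusers.utils.remote_utils import remote_decode


def decode_worker(q: queue.Queue):
    while True:
        item = q.get()
        if item is None:  # sentinel: stop the worker
            break
        image = remote_decode(
            endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
            tensor=item,
            scaling_factor=0.18215,
        )
        image.save("decoded.jpg")  # the blog displays the image in a notebook instead
        q.task_done()


q = queue.Queue()
worker = threading.Thread(target=decode_worker, args=(q,), daemon=True)
worker.start()

# A generation loop would put latents on the queue as they are produced:
q.put(torch.randn([1, 4, 64, 64], dtype=torch.float16))

q.put(None)  # stop the worker once generation is done
worker.join()
```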
