
Commit 2e2684f

thomaseding and asomoza authored
Add vae_roundtrip.py example (#7104)
* Add vae_roundtrip.py example
* Add cuda support to vae_roundtrip
* Move vae_roundtrip.py into research_projects/vae
* Fix channel scaling in vae roundtrip and also support taesd.
* Apply ruff --fix for CI gatekeep check

Co-authored-by: Álvaro Somoza <[email protected]>
1 parent 31adeb4 commit 2e2684f

File tree

2 files changed: +293 −0 lines changed
examples/research_projects/vae/README.md
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# VAE

`vae_roundtrip.py` demonstrates the use of a VAE by roundtripping an image through the encoder and decoder. Original and reconstructed images are displayed side by side.

```
cd examples/research_projects/vae
python vae_roundtrip.py \
    --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
    --subfolder="vae" \
    --input_image="/path/to/your/input.png"
```
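For reference, the same roundtrip can be done inline in Python. Below is a minimal, untested sketch of what the script does under the hood; it assumes the same `runwayml/stable-diffusion-v1-5` checkpoint, and `input.png` is a placeholder path:

```
import torch
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from PIL import Image
from torchvision import transforms

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").eval()
# ToTensor yields a CHW tensor in [0, 1]; unsqueeze adds the batch dim (NCHW).
img = transforms.ToTensor()(Image.open("input.png").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    latents = vae.encode(VaeImageProcessor.normalize(img)).latent_dist.sample()
    recon = VaeImageProcessor.denormalize(vae.decode(latents).sample)
transforms.ToPILImage()(recon.squeeze(0).clamp(0, 1)).show()
```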
examples/research_projects/vae/vae_roundtrip.py
Lines changed: 282 additions & 0 deletions
@@ -0,0 +1,282 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import typing
from typing import Optional, Union

import torch
from PIL import Image
from torchvision import transforms  # type: ignore

from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.autoencoder_kl import (
    AutoencoderKL,
    AutoencoderKLOutput,
)
from diffusers.models.autoencoders.autoencoder_tiny import (
    AutoencoderTiny,
    AutoencoderTinyOutput,
)
from diffusers.models.autoencoders.vae import DecoderOutput

SupportedAutoencoder = Union[AutoencoderKL, AutoencoderTiny]


def load_vae_model(
    *,
    device: torch.device,
    model_name_or_path: str,
    revision: Optional[str],
    variant: Optional[str],
    # NOTE: use subfolder="vae" if the pointed model is for stable diffusion as a whole instead of just the VAE
    subfolder: Optional[str],
    use_tiny_nn: bool,
) -> SupportedAutoencoder:
    if use_tiny_nn:
        # NOTE: These scaling factors don't have to be the same as each other.
        down_scale = 2
        up_scale = 2
        vae = AutoencoderTiny.from_pretrained(  # type: ignore
            model_name_or_path,
            subfolder=subfolder,
            revision=revision,
            variant=variant,
            downscaling_scaling_factor=down_scale,
            upsampling_scaling_factor=up_scale,
        )
        assert isinstance(vae, AutoencoderTiny)
    else:
        vae = AutoencoderKL.from_pretrained(  # type: ignore
            model_name_or_path,
            subfolder=subfolder,
            revision=revision,
            variant=variant,
        )
        assert isinstance(vae, AutoencoderKL)
    vae = vae.to(device)
    vae.eval()  # Set the model to inference mode
    return vae

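# A minimal sketch of calling load_vae_model with illustrative values (checkpoint
# names match the EXAMPLE USAGE section at the bottom of this file):
#
#     vae = load_vae_model(
#         device=torch.device("cpu"),
#         model_name_or_path="runwayml/stable-diffusion-v1-5",
#         revision=None,
#         variant=None,
#         subfolder="vae",  # the VAE lives in a subfolder of the full Stable Diffusion checkpoint
#         use_tiny_nn=False,  # True loads an AutoencoderTiny checkpoint such as madebyollin/taesd
#     )
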
def pil_to_nhwc(
    *,
    device: torch.device,
    image: Image.Image,
) -> torch.Tensor:
    # NOTE: despite the name, ToTensor produces CHW, so unsqueeze(0) yields an NCHW tensor.
    assert image.mode == "RGB"
    transform = transforms.ToTensor()
    nhwc = transform(image).unsqueeze(0).to(device)  # type: ignore
    assert isinstance(nhwc, torch.Tensor)
    return nhwc


def nhwc_to_pil(
    *,
    nhwc: torch.Tensor,
) -> Image.Image:
    assert nhwc.shape[0] == 1
    hwc = nhwc.squeeze(0).cpu()
    return transforms.ToPILImage()(hwc)  # type: ignore


def concatenate_images(
    *,
    left: Image.Image,
    right: Image.Image,
    vertical: bool = False,
) -> Image.Image:
    width1, height1 = left.size
    width2, height2 = right.size
    if vertical:
        total_height = height1 + height2
        max_width = max(width1, width2)
        new_image = Image.new("RGB", (max_width, total_height))
        new_image.paste(left, (0, 0))
        new_image.paste(right, (0, height1))
    else:
        total_width = width1 + width2
        max_height = max(height1, height2)
        new_image = Image.new("RGB", (total_width, max_height))
        new_image.paste(left, (0, 0))
        new_image.paste(right, (width1, 0))
    return new_image

def to_latent(
    *,
    rgb_nchw: torch.Tensor,
    vae: SupportedAutoencoder,
) -> torch.Tensor:
    rgb_nchw = VaeImageProcessor.normalize(rgb_nchw)  # type: ignore
    encoding_nchw = vae.encode(typing.cast(torch.FloatTensor, rgb_nchw))
    if isinstance(encoding_nchw, AutoencoderKLOutput):
        latent = encoding_nchw.latent_dist.sample()  # type: ignore
        assert isinstance(latent, torch.Tensor)
    elif isinstance(encoding_nchw, AutoencoderTinyOutput):
        latent = encoding_nchw.latents
        do_internal_vae_scaling = False  # Is this needed?
        if do_internal_vae_scaling:
            latent = vae.scale_latents(latent).mul(255).round().byte()  # type: ignore
            latent = vae.unscale_latents(latent / 255.0)  # type: ignore
        assert isinstance(latent, torch.Tensor)
    else:
        assert False, f"Unknown encoding type: {type(encoding_nchw)}"
    return latent


def from_latent(
    *,
    latent_nchw: torch.Tensor,
    vae: SupportedAutoencoder,
) -> torch.Tensor:
    decoding_nchw = vae.decode(latent_nchw)  # type: ignore
    assert isinstance(decoding_nchw, DecoderOutput)
    rgb_nchw = VaeImageProcessor.denormalize(decoding_nchw.sample)  # type: ignore
    assert isinstance(rgb_nchw, torch.Tensor)
    return rgb_nchw

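# For reference, a sketch of what the VaeImageProcessor.normalize/denormalize helpers
# used above amount to (illustrative affine maps, not the library source):
#
#     normalize(x)   -> 2.0 * x - 1.0                # [0, 1] -> [-1, 1]
#     denormalize(x) -> (x / 2 + 0.5).clamp(0, 1)    # [-1, 1] -> [0, 1]
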
def main_kwargs(
    *,
    device: torch.device,
    input_image_path: str,
    pretrained_model_name_or_path: str,
    revision: Optional[str],
    variant: Optional[str],
    subfolder: Optional[str],
    use_tiny_nn: bool,
) -> None:
    vae = load_vae_model(
        device=device,
        model_name_or_path=pretrained_model_name_or_path,
        revision=revision,
        variant=variant,
        subfolder=subfolder,
        use_tiny_nn=use_tiny_nn,
    )
    original_pil = Image.open(input_image_path).convert("RGB")
    original_image = pil_to_nhwc(
        device=device,
        image=original_pil,
    )
    print(f"Original image shape: {original_image.shape}")
    reconstructed_image: Optional[torch.Tensor] = None

    with torch.no_grad():
        latent_image = to_latent(rgb_nchw=original_image, vae=vae)
        print(f"Latent shape: {latent_image.shape}")
        reconstructed_image = from_latent(latent_nchw=latent_image, vae=vae)
    reconstructed_pil = nhwc_to_pil(nhwc=reconstructed_image)
    combined_image = concatenate_images(
        left=original_pil,
        right=reconstructed_pil,
        vertical=False,
    )
    combined_image.show("Original | Reconstruction")
    print(f"Reconstructed image shape: {reconstructed_image.shape}")

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Inference with VAE")
    parser.add_argument(
        "--input_image",
        type=str,
        required=True,
        help="Path to the input image for inference.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        required=True,
        help="Path to pretrained VAE model.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="Model version.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Model file variant, e.g., 'fp16'.",
    )
    parser.add_argument(
        "--subfolder",
        type=str,
        default=None,
        help="Subfolder in the model file.",
    )
    parser.add_argument(
        "--use_cuda",
        action="store_true",
        help="Use CUDA if available.",
    )
    parser.add_argument(
        "--use_tiny_nn",
        action="store_true",
        help="Use tiny neural network.",
    )
    return parser.parse_args()

# EXAMPLE USAGE:
#
# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "runwayml/stable-diffusion-v1-5" --subfolder "vae" --input_image "foo.png"
#
# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "madebyollin/taesd" --use_tiny_nn --input_image "foo.png"
#
def main_cli() -> None:
    args = parse_args()

    input_image_path = args.input_image
    assert isinstance(input_image_path, str)

    pretrained_model_name_or_path = args.pretrained_model_name_or_path
    assert isinstance(pretrained_model_name_or_path, str)

    revision = args.revision
    assert isinstance(revision, (str, type(None)))

    variant = args.variant
    assert isinstance(variant, (str, type(None)))

    subfolder = args.subfolder
    assert isinstance(subfolder, (str, type(None)))

    use_cuda = args.use_cuda
    assert isinstance(use_cuda, bool)

    use_tiny_nn = args.use_tiny_nn
    assert isinstance(use_tiny_nn, bool)

    device = torch.device("cuda" if use_cuda else "cpu")

    main_kwargs(
        device=device,
        input_image_path=input_image_path,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        revision=revision,
        variant=variant,
        subfolder=subfolder,
        use_tiny_nn=use_tiny_nn,
    )


if __name__ == "__main__":
    main_cli()
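For a quick TAESD spot-check without the CLI, the tiny-autoencoder path sketches out similarly. This is a minimal, untested sketch assuming the `madebyollin/taesd` checkpoint named in EXAMPLE USAGE, with `foo.png` as a placeholder path:

```
import torch
from diffusers import AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from PIL import Image
from torchvision import transforms

vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").eval()
img = transforms.ToTensor()(Image.open("foo.png").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    # Tiny-VAE latents are deterministic: .latents, not a latent_dist to sample.
    latents = vae.encode(VaeImageProcessor.normalize(img)).latents
    recon = VaeImageProcessor.denormalize(vae.decode(latents).sample)
transforms.ToPILImage()(recon.squeeze(0).clamp(0, 1)).show()
```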

0 commit comments
