Skip to content

Commit 8c5c543

Browse files
committed
autoencodertiny.
1 parent 8c3e871 commit 8c5c543

File tree

1 file changed

+197
-0
lines changed

1 file changed

+197
-0
lines changed
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
# coding=utf-8
2+
# Copyright 2024 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import gc
17+
import unittest
18+
19+
import torch
20+
from parameterized import parameterized
21+
22+
from diffusers import AutoencoderTiny
23+
from diffusers.utils.testing_utils import (
24+
backend_empty_cache,
25+
enable_full_determinism,
26+
floats_tensor,
27+
load_hf_numpy,
28+
slow,
29+
torch_all_close,
30+
torch_device,
31+
)
32+
33+
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
34+
35+
36+
enable_full_determinism()
37+
38+
39+
class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
40+
model_class = AutoencoderTiny
41+
main_input_name = "sample"
42+
base_precision = 1e-2
43+
44+
def get_autoencoder_tiny_config(self, block_out_channels=None):
45+
block_out_channels = (len(block_out_channels) * [32]) if block_out_channels is not None else [32, 32]
46+
init_dict = {
47+
"in_channels": 3,
48+
"out_channels": 3,
49+
"encoder_block_out_channels": block_out_channels,
50+
"decoder_block_out_channels": block_out_channels,
51+
"num_encoder_blocks": [b // min(block_out_channels) for b in block_out_channels],
52+
"num_decoder_blocks": [b // min(block_out_channels) for b in reversed(block_out_channels)],
53+
}
54+
return init_dict
55+
56+
@property
57+
def dummy_input(self):
58+
batch_size = 4
59+
num_channels = 3
60+
sizes = (32, 32)
61+
62+
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
63+
64+
return {"sample": image}
65+
66+
@property
67+
def input_shape(self):
68+
return (3, 32, 32)
69+
70+
@property
71+
def output_shape(self):
72+
return (3, 32, 32)
73+
74+
def prepare_init_args_and_inputs_for_common(self):
75+
init_dict = self.get_autoencoder_tiny_config()
76+
inputs_dict = self.dummy_input
77+
return init_dict, inputs_dict
78+
79+
@unittest.skip("Model doesn't yet support smaller resolution.")
80+
def test_enable_disable_tiling(self):
81+
pass
82+
83+
def test_enable_disable_slicing(self):
84+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
85+
86+
torch.manual_seed(0)
87+
model = self.model_class(**init_dict).to(torch_device)
88+
89+
inputs_dict.update({"return_dict": False})
90+
91+
torch.manual_seed(0)
92+
output_without_slicing = model(**inputs_dict)[0]
93+
94+
torch.manual_seed(0)
95+
model.enable_slicing()
96+
output_with_slicing = model(**inputs_dict)[0]
97+
98+
self.assertLess(
99+
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
100+
0.5,
101+
"VAE slicing should not affect the inference results",
102+
)
103+
104+
torch.manual_seed(0)
105+
model.disable_slicing()
106+
output_without_slicing_2 = model(**inputs_dict)[0]
107+
108+
self.assertEqual(
109+
output_without_slicing.detach().cpu().numpy().all(),
110+
output_without_slicing_2.detach().cpu().numpy().all(),
111+
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
112+
)
113+
114+
@unittest.skip("Test not supported.")
115+
def test_outputs_equivalence(self):
116+
pass
117+
118+
@unittest.skip("Test not supported.")
119+
def test_forward_with_norm_groups(self):
120+
pass
121+
122+
123+
@slow
124+
class AutoencoderTinyIntegrationTests(unittest.TestCase):
125+
def tearDown(self):
126+
# clean up the VRAM after each test
127+
super().tearDown()
128+
gc.collect()
129+
backend_empty_cache(torch_device)
130+
131+
def get_file_format(self, seed, shape):
132+
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
133+
134+
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
135+
dtype = torch.float16 if fp16 else torch.float32
136+
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
137+
return image
138+
139+
def get_sd_vae_model(self, model_id="hf-internal-testing/taesd-diffusers", fp16=False):
140+
torch_dtype = torch.float16 if fp16 else torch.float32
141+
142+
model = AutoencoderTiny.from_pretrained(model_id, torch_dtype=torch_dtype)
143+
model.to(torch_device).eval()
144+
return model
145+
146+
@parameterized.expand(
147+
[
148+
[(1, 4, 73, 97), (1, 3, 584, 776)],
149+
[(1, 4, 97, 73), (1, 3, 776, 584)],
150+
[(1, 4, 49, 65), (1, 3, 392, 520)],
151+
[(1, 4, 65, 49), (1, 3, 520, 392)],
152+
[(1, 4, 49, 49), (1, 3, 392, 392)],
153+
]
154+
)
155+
def test_tae_tiling(self, in_shape, out_shape):
156+
model = self.get_sd_vae_model()
157+
model.enable_tiling()
158+
with torch.no_grad():
159+
zeros = torch.zeros(in_shape).to(torch_device)
160+
dec = model.decode(zeros).sample
161+
assert dec.shape == out_shape
162+
163+
def test_stable_diffusion(self):
164+
model = self.get_sd_vae_model()
165+
image = self.get_sd_image(seed=33)
166+
167+
with torch.no_grad():
168+
sample = model(image).sample
169+
170+
assert sample.shape == image.shape
171+
172+
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
173+
expected_output_slice = torch.tensor([0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382])
174+
175+
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
176+
177+
@parameterized.expand([(True,), (False,)])
178+
def test_tae_roundtrip(self, enable_tiling):
179+
# load the autoencoder
180+
model = self.get_sd_vae_model()
181+
if enable_tiling:
182+
model.enable_tiling()
183+
184+
# make a black image with a white square in the middle,
185+
# which is large enough to split across multiple tiles
186+
image = -torch.ones(1, 3, 1024, 1024, device=torch_device)
187+
image[..., 256:768, 256:768] = 1.0
188+
189+
# round-trip the image through the autoencoder
190+
with torch.no_grad():
191+
sample = model(image).sample
192+
193+
# the autoencoder reconstruction should match original image, sorta
194+
def downscale(x):
195+
return torch.nn.functional.avg_pool2d(x, model.spatial_scale_factor)
196+
197+
assert torch_all_close(downscale(sample), downscale(image), atol=0.125)

0 commit comments

Comments
 (0)