[tests] refactor vae tests #9808
Changes from 18 commits
```diff
@@ -433,7 +433,7 @@ def create_forward(*inputs):
                     hidden_states,
                     temb,
                     zq,
-                    conv_cache=conv_cache.get(conv_cache_key),
+                    conv_cache.get(conv_cache_key),
                 )
             else:
                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
```
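All of the `conv_cache` hunks in this file make the same change: the cache is passed positionally instead of as a keyword argument. A reviewer's note at the end of this page explains why, with a sketch of the checkpointing constraint there.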
```diff
@@ -531,7 +531,7 @@ def create_forward(*inputs):
                     return create_forward

                 hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
                 )
             else:
                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
```
```diff
@@ -649,7 +649,7 @@ def create_forward(*inputs):
                     hidden_states,
                     temb,
                     zq,
-                    conv_cache=conv_cache.get(conv_cache_key),
+                    conv_cache.get(conv_cache_key),
                 )
             else:
                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
```
```diff
@@ -789,7 +789,7 @@ def custom_forward(*inputs):
                     hidden_states,
                     temb,
                     None,
-                    conv_cache=conv_cache.get(conv_cache_key),
+                    conv_cache.get(conv_cache_key),
                 )

             # 2. Mid
```
```diff
@@ -798,14 +798,14 @@ def custom_forward(*inputs):
                 hidden_states,
                 temb,
                 None,
-                conv_cache=conv_cache.get("mid_block"),
+                conv_cache.get("mid_block"),
             )
         else:
             # 1. Down
             for i, down_block in enumerate(self.down_blocks):
                 conv_cache_key = f"down_block_{i}"
                 hidden_states, new_conv_cache[conv_cache_key] = down_block(
-                    hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
+                    hidden_states, temb, None, conv_cache.get(conv_cache_key)
                 )

             # 2. Mid
```
```diff
@@ -953,7 +953,7 @@ def custom_forward(*inputs):
                     hidden_states,
                     temb,
                     sample,
-                    conv_cache=conv_cache.get("mid_block"),
+                    conv_cache.get("mid_block"),
                 )

             # 2. Up
```
```diff
@@ -964,7 +964,7 @@ def custom_forward(*inputs):
                     hidden_states,
                     temb,
                     sample,
-                    conv_cache=conv_cache.get(conv_cache_key),
+                    conv_cache.get(conv_cache_key),
                 )
             else:
                 # 1. Mid
```
```diff
@@ -1475,7 +1475,7 @@ def forward(
             z = posterior.sample(generator=generator)
         else:
             z = posterior.mode()
-        dec = self.decode(z)
+        dec = self.decode(z).sample
         if not return_dict:
             return (dec,)
-        return dec
+        return DecoderOutput(sample=dec)
```

Review comment: Otherwise we return a tuple of `DecoderOutput` objects.
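The review note is about the calling convention this fix restores: `decode()` wraps its result in a `DecoderOutput`, so `forward` must unwrap it with `.sample` before building its own tuple or `DecoderOutput`. A minimal sketch, using a tiny randomly initialized `AutoencoderKL` as a stand-in for the VAE this diff touches (the model and sizes here are illustrative, not from the PR):

```python
import torch
from diffusers import AutoencoderKL

# Tiny random-weight VAE, small enough to run on CPU.
vae = AutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=["DownEncoderBlock2D"],
    up_block_types=["UpDecoderBlock2D"],
    block_out_channels=[32],
    latent_channels=4,
    norm_num_groups=32,
    sample_size=32,
).eval()

z = torch.randn(1, 4, 16, 16)

with torch.no_grad():
    dec = vae.decode(z).sample                       # DecoderOutput -> tensor via .sample
    (dec_tuple,) = vae.decode(z, return_dict=False)  # plain tuple of tensors instead

assert torch.equal(dec, dec_tuple)  # same tensor either way
```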
```diff
@@ -228,14 +228,6 @@ def __init__(
         self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)

-        sample_size = (
-            self.config.sample_size[0]
-            if isinstance(self.config.sample_size, (list, tuple))
-            else self.config.sample_size
-        )
-        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
-        self.tile_overlap_factor = 0.25
-
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (Encoder, TemporalDecoder)):
             module.gradient_checkpointing = value

```

Review comment (on deleted lines 232-238): Unused.
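Per the review, these attributes were assigned in `__init__` but never read, so removing them is behavior-preserving.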
```diff
@@ -310,7 +310,9 @@ def decode(
         self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
     ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         if self.use_slicing and x.shape[0] > 1:
-            output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
+            output = [
+                self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)
+            ]
             output = torch.cat(output)
         else:
             output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
```

Review comment (on lines 313-315): Should use `x_slice` here.

Review comment: Could maybe further refactor this to how the current implementations of Cog/Mochi are with […]

Reply: Yeah sure, feel free to club those in your PR.
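A minimal sketch of the bug this hunk fixes, with a hypothetical `decoder` standing in for `self.decoder`: decoding `x` instead of `x_slice` runs the whole batch once per slice, so the concatenated output ends up N times too large.

```python
import torch

def decoder(t: torch.Tensor) -> torch.Tensor:
    return t * 2.0  # hypothetical stand-in for self.decoder

x = torch.randn(3, 4, 8, 8)  # batch of 3 latents

# Before the fix: the full batch is decoded once per slice,
# so torch.cat yields 3 * 3 = 9 samples instead of 3.
bad = torch.cat([decoder(x) for _ in x.split(1)])
assert bad.shape[0] == 9

# After the fix: each 1-sample slice is decoded on its own.
good = torch.cat([decoder(x_slice) for x_slice in x.split(1)])
assert good.shape[0] == 3
```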
```diff
@@ -341,7 +343,7 @@ def forward(
         # as if we were loading the latents from an RGBA uint8 image.
         unscaled_enc = self.unscale_latents(scaled_enc / 255.0)

-        dec = self.decode(unscaled_enc)
+        dec = self.decode(unscaled_enc).sample

         if not return_dict:
             return (dec,)
```
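This mirrors the `forward` fix in the first file: `self.decode(...)` returns a `DecoderOutput` by default, so `.sample` is needed to pull out the tensor before the tuple/`DecoderOutput` return.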
New test file (261 added lines):

```python
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import torch
from parameterized import parameterized

from diffusers import AsymmetricAutoencoderKL
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    floats_tensor,
    load_hf_numpy,
    require_torch_accelerator,
    require_torch_gpu,
    skip_mps,
    slow,
    torch_all_close,
    torch_device,
)

from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin


enable_full_determinism()


class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = AsymmetricAutoencoderKL
    main_input_name = "sample"
    base_precision = 1e-2

    def get_asym_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
        block_out_channels = block_out_channels or [2, 4]
        norm_num_groups = norm_num_groups or 2
        init_dict = {
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
            "down_block_out_channels": block_out_channels,
            "layers_per_down_block": 1,
            "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
            "up_block_out_channels": block_out_channels,
            "layers_per_up_block": 1,
            "act_fn": "silu",
            "latent_channels": 4,
            "norm_num_groups": norm_num_groups,
            "sample_size": 32,
            "scaling_factor": 0.18215,
        }
        return init_dict

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        mask = torch.ones((batch_size, 1) + sizes).to(torch_device)

        return {"sample": image, "mask": mask}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_asym_autoencoder_kl_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skip("Unsupported test.")
    def test_forward_with_norm_groups(self):
        pass


@slow
class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
    def get_file_format(self, seed, shape):
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
        dtype = torch.float16 if fp16 else torch.float32
        image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
        return image

    def get_sd_vae_model(self, model_id="cross-attention/asymmetric-autoencoder-kl-x-1-5", fp16=False):
        revision = "main"
        torch_dtype = torch.float32

        model = AsymmetricAutoencoderKL.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            revision=revision,
        )
        model.to(torch_device).eval()

        return model

    def get_generator(self, seed=0):
        generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
        if torch_device != "mps":
            return torch.Generator(device=generator_device).manual_seed(seed)
        return torch.manual_seed(seed)

    @parameterized.expand(
        [
            # fmt: off
            [
                33,
                [-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205],
                [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824],
            ],
            [
                47,
                [0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529],
                [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089],
            ],
            # fmt: on
        ]
    )
    def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)
        generator = self.get_generator(seed)

        with torch.no_grad():
            sample = model(image, generator=generator, sample_posterior=True).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

    @parameterized.expand(
        [
            # fmt: off
            [
                33,
                [-0.0340, 0.2870, 0.1698, -0.0105, -0.3448, 0.3529, -0.1321, 0.1097],
                [-0.0344, 0.2912, 0.1687, -0.0137, -0.3462, 0.3552, -0.1337, 0.1078],
            ],
            [
                47,
                [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531],
                [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531],
            ],
            # fmt: on
        ]
    )
    def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps):
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)

        with torch.no_grad():
            sample = model(image).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)

    @parameterized.expand(
        [
            # fmt: off
            [13, [-0.0521, -0.2939, 0.1540, -0.1855, -0.5936, -0.3138, -0.4579, -0.2275]],
            [37, [-0.1820, -0.4345, -0.0455, -0.2923, -0.8035, -0.5089, -0.4795, -0.3106]],
            # fmt: on
        ]
    )
    @require_torch_accelerator
    @skip_mps
    def test_stable_diffusion_decode(self, seed, expected_slice):
        model = self.get_sd_vae_model()
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))

        with torch.no_grad():
            sample = model.decode(encoding).sample

        assert list(sample.shape) == [3, 3, 512, 512]

        output_slice = sample[-1, -2:, :2, -2:].flatten().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=2e-3)

    @parameterized.expand([(13,), (16,), (37,)])
    @require_torch_gpu
    @unittest.skipIf(
        not is_xformers_available(),
        reason="xformers is not required when using PyTorch 2.0.",
    )
    def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
        model = self.get_sd_vae_model()
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))

        with torch.no_grad():
            sample = model.decode(encoding).sample

        model.enable_xformers_memory_efficient_attention()
        with torch.no_grad():
            sample_2 = model.decode(encoding).sample

        assert list(sample.shape) == [3, 3, 512, 512]

        assert torch_all_close(sample, sample_2, atol=5e-2)

    @parameterized.expand(
        [
            # fmt: off
            [33, [-0.3001, 0.0918, -2.6984, -3.9720, -3.2099, -5.0353, 1.7338, -0.2065, 3.4267]],
            [47, [-1.5030, -4.3871, -6.0355, -9.1157, -1.6661, -2.7853, 2.1607, -5.0823, 2.5633]],
            # fmt: on
        ]
    )
    def test_stable_diffusion_encode_sample(self, seed, expected_slice):
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)
        generator = self.get_generator(seed)

        with torch.no_grad():
            dist = model.encode(image).latent_dist
            sample = dist.sample(generator=generator)

        assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]

        output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        tolerance = 3e-3 if torch_device != "mps" else 1e-2
        assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
```
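As a usage sketch (not part of the PR): the fast-test config above is small enough to instantiate on CPU. Everything below mirrors `get_asym_autoencoder_kl_config` and `dummy_input` from the test file; only the batch handling is mine.

```python
import torch
from diffusers import AsymmetricAutoencoderKL

# Same tiny config the unit tests build.
model = AsymmetricAutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=["DownEncoderBlock2D"] * 2,
    down_block_out_channels=[2, 4],
    layers_per_down_block=1,
    up_block_types=["UpDecoderBlock2D"] * 2,
    up_block_out_channels=[2, 4],
    layers_per_up_block=1,
    act_fn="silu",
    latent_channels=4,
    norm_num_groups=2,
    sample_size=32,
    scaling_factor=0.18215,
).eval()

sample = torch.randn(4, 3, 32, 32)  # matches dummy_input's image
mask = torch.ones(4, 1, 32, 32)     # matches dummy_input's mask

with torch.no_grad():
    out = model(sample, mask).sample

# Same shape check the integration tests make: output matches input.
assert out.shape == sample.shape
```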
Review comment (explaining the `conv_cache` changes above): Because the `torch.utils.checkpoint.checkpoint()` method doesn't have any `conv_cache` argument.
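A minimal sketch of that constraint, assuming PyTorch's reentrant checkpointing (the default when `use_reentrant` is unset, made explicit below); `run_block` is a hypothetical stand-in for the resnet call:

```python
import torch
import torch.utils.checkpoint

def run_block(hidden_states, conv_cache=None):
    # Stand-in for resnet(hidden_states, temb, zq, conv_cache=...).
    out = hidden_states * 2.0
    if conv_cache is not None:
        out = out + conv_cache
    return out

hidden_states = torch.randn(2, 4, requires_grad=True)
conv_cache = torch.ones(2, 4)

# Reentrant checkpointing forwards only positional arguments to the
# wrapped function; an extra keyword argument is rejected:
try:
    torch.utils.checkpoint.checkpoint(
        run_block, hidden_states, conv_cache=conv_cache, use_reentrant=True
    )
except (TypeError, ValueError) as err:
    print(f"kwargs rejected: {err}")

# Passing the cache positionally, as the diff does, works: conv_cache
# binds to run_block's second parameter.
out = torch.utils.checkpoint.checkpoint(
    run_block, hidden_states, conv_cache, use_reentrant=True
)
```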