Merged
20 changes: 10 additions & 10 deletions src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -433,7 +433,7 @@ def create_forward(*inputs):
hidden_states,
temb,
zq,
-    conv_cache=conv_cache.get(conv_cache_key),
+    conv_cache.get(conv_cache_key),
@sayakpaul (Member Author), Nov 12, 2024:
Because the torch.utils.checkpoint.checkpoint() method doesn't accept a conv_cache keyword argument; passing it positionally lets create_custom_forward() forward it to the underlying resnet.
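A minimal sketch of the failure mode, assuming the reentrant checkpoint path that a bare torch.utils.checkpoint.checkpoint() call uses by default here; fn and cache are hypothetical stand-ins for the resnet wrapper and conv_cache:

    import torch
    from torch.utils.checkpoint import checkpoint

    def fn(x, cache):
        # `cache` plays the role of conv_cache
        return x * 2 if cache is None else x + cache

    x = torch.randn(2, requires_grad=True)

    # Extra positional args are forwarded to the wrapped function: works.
    out = checkpoint(fn, x, None, use_reentrant=True)

    # Extra keyword args are not forwarded by the reentrant implementation;
    # it rejects them with "Unexpected keyword arguments" instead:
    # out = checkpoint(fn, x, cache=None, use_reentrant=True)  # raises ValueError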

)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -531,7 +531,7 @@ def create_forward(*inputs):
return create_forward

hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -649,7 +649,7 @@ def create_forward(*inputs):
hidden_states,
temb,
zq,
-    conv_cache=conv_cache.get(conv_cache_key),
+    conv_cache.get(conv_cache_key),
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -789,7 +789,7 @@ def custom_forward(*inputs):
hidden_states,
temb,
None,
-    conv_cache=conv_cache.get(conv_cache_key),
+    conv_cache.get(conv_cache_key),
)

# 2. Mid
@@ -798,14 +798,14 @@ def custom_forward(*inputs):
hidden_states,
temb,
None,
-    conv_cache=conv_cache.get("mid_block"),
+    conv_cache.get("mid_block"),
)
else:
# 1. Down
for i, down_block in enumerate(self.down_blocks):
conv_cache_key = f"down_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = down_block(
-    hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
+    hidden_states, temb, None, conv_cache.get(conv_cache_key)
)

# 2. Mid
@@ -953,7 +953,7 @@ def custom_forward(*inputs):
hidden_states,
temb,
sample,
-    conv_cache=conv_cache.get("mid_block"),
+    conv_cache.get("mid_block"),
)

# 2. Up
@@ -964,7 +964,7 @@ def custom_forward(*inputs):
hidden_states,
temb,
sample,
-    conv_cache=conv_cache.get(conv_cache_key),
+    conv_cache.get(conv_cache_key),
)
else:
# 1. Mid
@@ -1475,7 +1475,7 @@ def forward(
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
-    dec = self.decode(z)
+    dec = self.decode(z).sample
@sayakpaul (Member Author):
Otherwise we return a tuple containing a DecoderOutput (rather than the sample tensor) when return_dict=False.
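A quick illustration of the bug; DecoderOutput is the actual diffusers dataclass (import path as of recent versions), the tensor is made up:

    import torch
    from diffusers.models.autoencoders.vae import DecoderOutput

    dec = DecoderOutput(sample=torch.randn(1, 3, 8, 8))  # what self.decode(z) returns
    # without .sample, return_dict=False yielded (DecoderOutput(...),)
    # with .sample, it yields (tensor,) as documented:
    assert isinstance(dec.sample, torch.Tensor)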

if not return_dict:
return (dec,)
-    return dec
+    return DecoderOutput(sample=dec)
src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
@@ -228,14 +228,6 @@ def __init__(

self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)

-    sample_size = (
-        self.config.sample_size[0]
-        if isinstance(self.config.sample_size, (list, tuple))
-        else self.config.sample_size
-    )
-    self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
-    self.tile_overlap_factor = 0.25
Comment on lines -232 to -238
@sayakpaul (Member Author):
Unused.


def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Encoder, TemporalDecoder)):
module.gradient_checkpointing = value
6 changes: 4 additions & 2 deletions src/diffusers/models/autoencoders/autoencoder_tiny.py
@@ -310,7 +310,9 @@ def decode(
self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
if self.use_slicing and x.shape[0] > 1:
-    output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
+    output = [
+        self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)
+    ]
Comment on lines -313 to +315
@sayakpaul (Member Author):
Should use x_slice and not x; otherwise every slice decodes the full batch.
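A toy repro of the shape bug, with nn.Identity standing in for self.decoder:

    import torch

    x = torch.randn(4, 4, 8, 8)    # a batch of 4 latents
    decoder = torch.nn.Identity()  # stand-in for self.decoder

    old = torch.cat([decoder(x) for x_slice in x.split(1)])        # (16, 4, 8, 8): full batch decoded per slice
    new = torch.cat([decoder(x_slice) for x_slice in x.split(1)])  # (4, 4, 8, 8): correct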

Contributor:
Could maybe refactor this further to match the current Cog/Mochi implementations, which use a _decode method. The code flow is a bit easier to understand that way.

@sayakpaul (Member Author):
Yeah sure feel free to club those in your PR.
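For reference, a rough sketch of the Cog/Mochi-style layout being suggested; the _decode split is the hypothetical part, the other names (use_tiling, use_slicing, _tiled_decode, DecoderOutput) mirror this file:

    def _decode(self, x: torch.Tensor) -> torch.Tensor:
        # the tiling decision lives in exactly one place
        return self._tiled_decode(x) if self.use_tiling else self.decoder(x)

    def decode(self, x, generator=None, return_dict=True):
        if self.use_slicing and x.shape[0] > 1:
            output = torch.cat([self._decode(x_slice) for x_slice in x.split(1)])
        else:
            output = self._decode(x)
        if not return_dict:
            return (output,)
        return DecoderOutput(sample=output)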

output = torch.cat(output)
else:
output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
@@ -341,7 +343,7 @@ def forward(
# as if we were loading the latents from an RGBA uint8 image.
unscaled_enc = self.unscale_latents(scaled_enc / 255.0)

-    dec = self.decode(unscaled_enc)
+    dec = self.decode(unscaled_enc).sample

if not return_dict:
return (dec,)
261 changes: 261 additions & 0 deletions tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py
@@ -0,0 +1,261 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import torch
from parameterized import parameterized

from diffusers import AsymmetricAutoencoderKL
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    floats_tensor,
    load_hf_numpy,
    require_torch_accelerator,
    require_torch_gpu,
    skip_mps,
    slow,
    torch_all_close,
    torch_device,
)

from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin


enable_full_determinism()


class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = AsymmetricAutoencoderKL
    main_input_name = "sample"
    base_precision = 1e-2

    def get_asym_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
        block_out_channels = block_out_channels or [2, 4]
        norm_num_groups = norm_num_groups or 2
        init_dict = {
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
            "down_block_out_channels": block_out_channels,
            "layers_per_down_block": 1,
            "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
            "up_block_out_channels": block_out_channels,
            "layers_per_up_block": 1,
            "act_fn": "silu",
            "latent_channels": 4,
            "norm_num_groups": norm_num_groups,
            "sample_size": 32,
            "scaling_factor": 0.18215,
        }
        return init_dict

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        mask = torch.ones((batch_size, 1) + sizes).to(torch_device)

        return {"sample": image, "mask": mask}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_asym_autoencoder_kl_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skip("Unsupported test.")
    def test_forward_with_norm_groups(self):
        pass


@slow
class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
    def get_file_format(self, seed, shape):
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
        dtype = torch.float16 if fp16 else torch.float32
        image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
        return image

    def get_sd_vae_model(self, model_id="cross-attention/asymmetric-autoencoder-kl-x-1-5", fp16=False):
        revision = "main"
        torch_dtype = torch.float32

        model = AsymmetricAutoencoderKL.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            revision=revision,
        )
        model.to(torch_device).eval()

        return model

    def get_generator(self, seed=0):
        generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
        if torch_device != "mps":
            return torch.Generator(device=generator_device).manual_seed(seed)
        return torch.manual_seed(seed)

    @parameterized.expand(
        [
            # fmt: off
            [
                33,
                [-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205],
                [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824],
            ],
            [
                47,
                [0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529],
                [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089],
            ],
            # fmt: on
        ]
    )
    def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)
        generator = self.get_generator(seed)

        with torch.no_grad():
            sample = model(image, generator=generator, sample_posterior=True).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

    @parameterized.expand(
        [
            # fmt: off
            [
                33,
                [-0.0340, 0.2870, 0.1698, -0.0105, -0.3448, 0.3529, -0.1321, 0.1097],
                [-0.0344, 0.2912, 0.1687, -0.0137, -0.3462, 0.3552, -0.1337, 0.1078],
            ],
            [
                47,
                [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531],
                [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531],
            ],
            # fmt: on
        ]
    )
    def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps):
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)

        with torch.no_grad():
            sample = model(image).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)

    @parameterized.expand(
        [
            # fmt: off
            [13, [-0.0521, -0.2939, 0.1540, -0.1855, -0.5936, -0.3138, -0.4579, -0.2275]],
            [37, [-0.1820, -0.4345, -0.0455, -0.2923, -0.8035, -0.5089, -0.4795, -0.3106]],
            # fmt: on
        ]
    )
    @require_torch_accelerator
    @skip_mps
    def test_stable_diffusion_decode(self, seed, expected_slice):
        model = self.get_sd_vae_model()
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))

        with torch.no_grad():
            sample = model.decode(encoding).sample

        assert list(sample.shape) == [3, 3, 512, 512]

        output_slice = sample[-1, -2:, :2, -2:].flatten().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=2e-3)

    @parameterized.expand([(13,), (16,), (37,)])
    @require_torch_gpu
    @unittest.skipIf(
        not is_xformers_available(),
        reason="xformers is not required when using PyTorch 2.0.",
    )
    def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
        model = self.get_sd_vae_model()
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))

        with torch.no_grad():
            sample = model.decode(encoding).sample

        model.enable_xformers_memory_efficient_attention()
        with torch.no_grad():
            sample_2 = model.decode(encoding).sample

        assert list(sample.shape) == [3, 3, 512, 512]

        assert torch_all_close(sample, sample_2, atol=5e-2)

    @parameterized.expand(
        [
            # fmt: off
            [33, [-0.3001, 0.0918, -2.6984, -3.9720, -3.2099, -5.0353, 1.7338, -0.2065, 3.4267]],
            [47, [-1.5030, -4.3871, -6.0355, -9.1157, -1.6661, -2.7853, 2.1607, -5.0823, 2.5633]],
            # fmt: on
        ]
    )
    def test_stable_diffusion_encode_sample(self, seed, expected_slice):
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)
        generator = self.get_generator(seed)

        with torch.no_grad():
            dist = model.encode(image).latent_dist
            sample = dist.sample(generator=generator)

        assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]

        output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        tolerance = 3e-3 if torch_device != "mps" else 1e-2
        assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)