From 55a1abd497915f59a29da910c385eb5ab31235d8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 2 Sep 2024 12:37:33 +0200 Subject: [PATCH 1/6] bugfix: precedence of operations should be slicing -> tiling --- .../models/autoencoders/autoencoder_kl.py | 44 +++++++++---------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 161770c67cf8..3e9eab4705e0 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -245,6 +245,18 @@ def set_default_attn_processor(self): self.set_attn_processor(processor) + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + enc = self.encoder(x) + if self.quant_conv is not None: + enc = self.quant_conv(x) + + return enc + @apply_forward_hook def encode( self, x: torch.Tensor, return_dict: bool = True @@ -261,21 +273,13 @@ def encode( The latent representations of the encoded images. If `return_dict` is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ - if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): - return self.tiled_encode(x, return_dict=return_dict) - if self.use_slicing and x.shape[0] > 1: - encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] h = torch.cat(encoded_slices) else: - h = self.encoder(x) + h = self._encode(x) - if self.quant_conv is not None: - moments = self.quant_conv(h) - else: - moments = h - - posterior = DiagonalGaussianDistribution(moments) + posterior = DiagonalGaussianDistribution(h) if not return_dict: return (posterior,) @@ -337,7 +341,7 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch. b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) return b - def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput: + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: r"""Encode a batch of images using a tiled encoder. When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several @@ -348,13 +352,10 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Autoencoder Args: x (`torch.Tensor`): Input batch of images. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. Returns: - [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: - If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain - `tuple` is returned. + `torch.Tensor`: + The latent representation of the encoded videos. """ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) @@ -384,13 +385,8 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Autoencoder result_row.append(tile[:, :, :row_limit, :row_limit]) result_rows.append(torch.cat(result_row, dim=3)) - moments = torch.cat(result_rows, dim=2) - posterior = DiagonalGaussianDistribution(moments) - - if not return_dict: - return (posterior,) - - return AutoencoderKLOutput(latent_dist=posterior) + enc = torch.cat(result_rows, dim=2) + return enc def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: r""" From 93e4d2340735a1327271c175a8e694df44ee97be Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 2 Sep 2024 13:58:14 +0200 Subject: [PATCH 2/6] fix typo --- src/diffusers/models/autoencoders/autoencoder_kl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 3e9eab4705e0..214717a56178 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -253,7 +253,7 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor: enc = self.encoder(x) if self.quant_conv is not None: - enc = self.quant_conv(x) + enc = self.quant_conv(enc) return enc From 5ac6473874bc15eb7254c40f50725826c0678ea9 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 2 Sep 2024 15:22:33 +0200 Subject: [PATCH 3/6] fix another typo --- src/diffusers/models/autoencoders/autoencoder_kl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 214717a56178..a3b22ee992f0 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -248,7 +248,7 @@ def set_default_attn_processor(self): def _encode(self, x: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = x.shape - if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size): return self.tiled_encode(x) enc = self.encoder(x) From ea25b6918420bcce5d557f9cc38ef151ba234d3f Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 25 Sep 2024 01:28:13 +0200 Subject: [PATCH 4/6] deprecate current implementation of tiled_encode and use new impl --- .../models/autoencoders/autoencoder_kl.py | 74 ++++++++++++++++++- 1 file changed, 72 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index a3b22ee992f0..a3caeddb8035 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -18,6 +18,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders.single_file_model import FromOriginalModelMixin +from ...utils import deprecate from ...utils.accelerate_utils import apply_forward_hook from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -249,7 +250,7 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = x.shape if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size): - return self.tiled_encode(x) + return self._tiled_encode(x) enc = self.encoder(x) if self.quant_conv is not None: @@ -341,7 +342,7 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch. b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) return b - def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: r"""Encode a batch of images using a tiled encoder. When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several @@ -357,6 +358,13 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: `torch.Tensor`: The latent representation of the encoded videos. """ + deprecation_message = ( + "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the " + "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able " + "to pass `return_dict`. You will also have to also create a `DiagonalGaussianDistribution()` from the returned value." + ) + deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False) + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) row_limit = self.tile_latent_min_size - blend_extent @@ -388,6 +396,68 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: enc = torch.cat(result_rows, dim=2) return enc + def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. + """ + deprecation_message = ( + "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the " + "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able " + "to pass `return_dict`. You will also have to also create a `DiagonalGaussianDistribution()` from the returned value." + ) + deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False) + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: r""" Decode a batch of images using a tiled decoder. From 02068221b6631b61b2dcf1361baebcb6f97fb633 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 26 Sep 2024 05:59:27 +0530 Subject: [PATCH 5/6] Update src/diffusers/models/autoencoders/autoencoder_kl.py Co-authored-by: YiYi Xu --- src/diffusers/models/autoencoders/autoencoder_kl.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index a3caeddb8035..49542ddfa83a 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -358,12 +358,6 @@ def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: `torch.Tensor`: The latent representation of the encoded videos. """ - deprecation_message = ( - "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the " - "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able " - "to pass `return_dict`. You will also have to also create a `DiagonalGaussianDistribution()` from the returned value." - ) - deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False) overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) From 896c8f7ee50a9514d4836c384548b03a4f82ab19 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 26 Sep 2024 05:59:55 +0530 Subject: [PATCH 6/6] Update src/diffusers/models/autoencoders/autoencoder_kl.py --- src/diffusers/models/autoencoders/autoencoder_kl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 49542ddfa83a..99a7da4a0b6f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -412,7 +412,7 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Autoencoder deprecation_message = ( "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the " "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able " - "to pass `return_dict`. You will also have to also create a `DiagonalGaussianDistribution()` from the returned value." + "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value." ) deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)