Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 20 additions & 24 deletions src/diffusers/models/autoencoders/autoencoder_kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,18 @@ def set_default_attn_processor(self):

self.set_attn_processor(processor)

def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape

if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
return self.tiled_encode(x)

enc = self.encoder(x)
if self.quant_conv is not None:
enc = self.quant_conv(enc)

return enc

@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
Expand All @@ -261,21 +273,13 @@ def encode(
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
return self.tiled_encode(x, return_dict=return_dict)

if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self.encoder(x)
h = self._encode(x)

if self.quant_conv is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think algorithem changed a bit for use_slicing
previously, apply quant_conv once after combining encoder outputs from all slice
currently, apply quant_conv on each slice

I'm pretty sure the result would be the same, I wonder if there is any implication on performance?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the performance should be the same since just one convolution layer on compressed outputs of encoder. I can get some numbers soon

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could perhaps add a test to ensure this? That should clear the confusions?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@a-r-r-o-w do you think it could make sense add a fast test here or not really?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's okay without a test here. The functionality is effectively similar and only affects the "batch_size" dim across this conv layer - which will never alter outputs as conv doesn't operate on that.

I know that understanding the changes here is quite easy, but I feel I should leave a comment making the explanation a bit more clear and elaborate for anyone stumbling upon this in the future.

Previously, slicing worked individually and tiling worked individually. When both were enabled, only tiling would be in effect meaning it would chop [B, C, H, W] to 4 tiles of shape [B, C, H // 2, W // 2] (assuming we have 2x2 perfect tiles), process each tile individually and perform blending.

This is incorrect as slicing is completely ignored. The correct processing size, ensuring slicing also took effect, would be 4 x B tiles with shape [1, C, H // 2, W // 2] - which this PR does.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for explaining!

moments = self.quant_conv(h)
else:
moments = h

posterior = DiagonalGaussianDistribution(moments)
posterior = DiagonalGaussianDistribution(h)

if not return_dict:
return (posterior,)
Expand Down Expand Up @@ -337,7 +341,7 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b

def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe, if we concern breaking, we can deprecate tiled_encode and make a new one called _tiled_encode

Copy link
Collaborator

@yiyixuxu yiyixuxu Sep 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, actually prefer to do that, I do see some usage of vae.titled_encode() https://github.com/search?q=%22pipe.tiled_encode%22+OR+%22vae.tiled_encode%22+OR+%22pipeline.tiled_encode%22&type=code ; also our current implementation of titled_encode is something can be used on its own, the new one is more like a private method that has to be called inside _encode

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I'll make a new method

r"""Encode a batch of images using a tiled encoder.

When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
Expand All @@ -348,13 +352,10 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Autoencoder

Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

Returns:
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
`tuple` is returned.
`torch.Tensor`:
The latent representation of the encoded videos.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
Expand Down Expand Up @@ -384,13 +385,8 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Autoencoder
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))

moments = torch.cat(result_rows, dim=2)
posterior = DiagonalGaussianDistribution(moments)

if not return_dict:
return (posterior,)

return AutoencoderKLOutput(latent_dist=posterior)
enc = torch.cat(result_rows, dim=2)
return enc

def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Expand Down