Skip to content

Conversation

@a-r-r-o-w
Copy link
Contributor

What does this PR do?

Discovered in internal discussions on Slack.

Currently, if you enable both slicing and tiling, only tiling occurs in AutoencoderKL. This is incorrect because it means that memory usage will not be constant if batch size changes, which could be the case with many production applications built on top of diffusers. This PR fixes the behaviour by ensuring that slicing takes precedence over tiling similar to the vae.decode method.

This change is a bit backwards breaking in terms of the return_dict parameter in tiled_encode but should be safe, I think. Some reference usage can be found here: https://github.com/search?q=%22pipe.tiled_encode%22+OR+%22vae.tiled_encode%22+OR+%22pipeline.tiled_encode%22&type=code

The reason for removing the return_dict parameter is because it creates unnecessary complication and introduces multiple branches to handle posterior distribution correctly depending on whether tiling is enable or not. I gave this some thought when making changes in #9340 and don't really see a clean way to address this.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DN6

@a-r-r-o-w a-r-r-o-w requested a review from DN6 September 2, 2024 10:45
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the PR,
I left a question!

h = self.encoder(x)
h = self._encode(x)

if self.quant_conv is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think algorithem changed a bit for use_slicing
previously, apply quant_conv once after combining encoder outputs from all slice
currently, apply quant_conv on each slice

I'm pretty sure the result would be the same, I wonder if there is any implication on performance?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the performance should be the same since just one convolution layer on compressed outputs of encoder. I can get some numbers soon

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could perhaps add a test to ensure this? That should clear the confusions?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@a-r-r-o-w do you think it could make sense add a fast test here or not really?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's okay without a test here. The functionality is effectively similar and only affects the "batch_size" dim across this conv layer - which will never alter outputs as conv doesn't operate on that.

I know that understanding the changes here is quite easy, but I feel I should leave a comment making the explanation a bit more clear and elaborate for anyone stumbling upon this in the future.

Previously, slicing worked individually and tiling worked individually. When both were enabled, only tiling would be in effect meaning it would chop [B, C, H, W] to 4 tiles of shape [B, C, H // 2, W // 2] (assuming we have 2x2 perfect tiles), process each tile individually and perform blending.

This is incorrect as slicing is completely ignored. The correct processing size, ensuring slicing also took effect, would be 4 x B tiles with shape [1, C, H // 2, W // 2] - which this PR does.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for explaining!

@a-r-r-o-w
Copy link
Contributor Author

After a couple of runs of the following code, I'm actually seeing that this branch is about 0.1-0.3 seconds faster than diffusers:main for a batch size of 8, but it may just be random fluctuations. But, the memory savings is correct and as expected in this branch as compared to main.

import gc
import random

import numpy as np
import torch
from diffusers import AutoencoderKL


def reset_memory(device):
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.reset_accumulated_memory_stats(device)


def print_memory(device):
    max_memory = torch.cuda.max_memory_allocated(device) / 1024**3
    max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
    print(f"{max_memory=:.2f}")
    print(f"{max_reserved=:.2f}")


@torch.no_grad()
def main():
    device = "cuda"
    dtype = torch.float16

    reset_memory(device)

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
    vae.to(device, dtype=dtype)

    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    input_1 = torch.randn((1, 3, 1024, 1024), device=device, dtype=dtype)
    input_8 = torch.randn((8, 3, 1024, 1024), device=device, dtype=dtype)

    # Warmup
    for _ in range(3):
        _encode = vae.encode(input_1).latent_dist.sample()
        _decode = vae.decode(_encode).sample

    torch.cuda.synchronize(device)
    del _encode, _decode
    reset_memory(device)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    encode_original_1 = vae.encode(input_1).latent_dist.sample(generator=torch.Generator().manual_seed(seed))
    decode_original_1 = vae.decode(encode_original_1).sample
    end.record()
    torch.cuda.synchronize(device)

    print("===== encode-decode-1 =====")
    print(f"Time: {start.elapsed_time(end):.3f}")
    print_memory(device)
    reset_memory(device)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    encode_original_8 = vae.encode(input_8).latent_dist.sample(generator=torch.Generator().manual_seed(seed))
    decode_original_8 = vae.decode(encode_original_8).sample
    end.record()
    torch.cuda.synchronize(device)

    print("===== encode-decode-8 =====")
    print(f"Time: {start.elapsed_time(end):.3f}")
    print_memory(device)
    reset_memory(device)

    vae.enable_slicing()
    vae.enable_tiling()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    encode_enabled_1 = vae.encode(input_1).latent_dist.sample(generator=torch.Generator().manual_seed(seed))
    decode_enabled_1 = vae.decode(encode_enabled_1).sample
    end.record()
    torch.cuda.synchronize(device)

    print("===== encode-decode-slicing-tiling-1 =====")
    print(f"Time: {start.elapsed_time(end):.3f}")
    print_memory(device)
    reset_memory(device)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    encode_enabled_8 = vae.encode(input_8).latent_dist.sample(generator=torch.Generator().manual_seed(seed))
    decode_enabled_8 = vae.decode(encode_enabled_8).sample
    end.record()
    torch.cuda.synchronize(device)

    print("===== encode-decode-slicing-tiling-8 =====")
    print(f"Time: {start.elapsed_time(end):.3f}")
    print_memory(device)
    reset_memory(device)


if __name__ == "__main__":
    main()

diffusers:main:

===== encode-decode-1 =====
Time: 1063.304
max_memory=2.60
max_reserved=3.83
===== encode-decode-4 =====
Time: 2660.009
max_memory=16.23
max_reserved=26.77
===== encode-decode-slicing-tiling-1 =====
Time: 2375.430
max_memory=0.44
max_reserved=2.27
===== encode-decode-slicing-tiling-4 =====
Time: 7469.336
max_memory=1.03
max_reserved=2.27

This branch:

===== encode-decode-1 =====
Time: 1080.947
max_memory=2.60
max_reserved=3.83
===== encode-decode-8 =====
Time: 2546.400
max_memory=16.23
max_reserved=26.77
===== encode-decode-slicing-tiling-1 =====
Time: 2353.985
max_memory=0.44
max_reserved=2.27
===== encode-decode-slicing-tiling-8 =====
Time: 7105.924
max_memory=0.49
max_reserved=2.27

@a-r-r-o-w
Copy link
Contributor Author

@yiyixuxu @sayakpaul Gentle ping

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Single comment

h = self.encoder(x)
h = self._encode(x)

if self.quant_conv is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could perhaps add a test to ensure this? That should clear the confusions?

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

return b

def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe, if we concern breaking, we can deprecate tiled_encode and make a new one called _tiled_encode

Copy link
Collaborator

@yiyixuxu yiyixuxu Sep 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, actually prefer to do that, I do see some usage of vae.titled_encode() https://github.com/search?q=%22pipe.tiled_encode%22+OR+%22vae.tiled_encode%22+OR+%22pipeline.tiled_encode%22&type=code ; also our current implementation of titled_encode is something can be used on its own, the new one is more like a private method that has to be called inside _encode

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I'll make a new method

@a-r-r-o-w a-r-r-o-w requested a review from yiyixuxu September 24, 2024 23:39
Comment on lines 361 to 366
deprecation_message = (
"The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
"implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
"to pass `return_dict`. You will also have to also create a `DiagonalGaussianDistribution()` from the returned value."
)
deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should test this deprecation too.

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very nice!
left one comment - feel free to merge once that's addressed

@a-r-r-o-w a-r-r-o-w merged commit 665c6b4 into main Sep 26, 2024
18 checks passed
@a-r-r-o-w a-r-r-o-w deleted the vae/bugfix-slicing-tiling branch September 26, 2024 16:42
leisuzz pushed a commit to leisuzz/diffusers that referenced this pull request Oct 11, 2024
…ggingface#9342)

* bugfix: precedence of operations should be slicing -> tiling

* fix typo

* fix another typo

* deprecate current implementation of tiled_encode and use new impl

* Update src/diffusers/models/autoencoders/autoencoder_kl.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/models/autoencoders/autoencoder_kl.py

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
sayakpaul added a commit that referenced this pull request Dec 23, 2024
)

* bugfix: precedence of operations should be slicing -> tiling

* fix typo

* fix another typo

* deprecate current implementation of tiled_encode and use new impl

* Update src/diffusers/models/autoencoders/autoencoder_kl.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/models/autoencoders/autoencoder_kl.py

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants