
Commit d4b8ef3

LEditsPP - examples, check height/width, add tiling/slicing

1 parent b572635 · commit d4b8ef3

File tree: 2 files changed, +99 −19 lines

  src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
  src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py

src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py

Lines changed: 43 additions & 6 deletions
@@ -34,21 +34,21 @@
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
-        >>> import PIL
-        >>> import requests
         >>> import torch
-        >>> from io import BytesIO

         >>> from diffusers import LEditsPPPipelineStableDiffusion
         >>> from diffusers.utils import load_image

         >>> pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
-        ...     "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+        ...     "runwayml/stable-diffusion-v1-5",
+        ...     variant="fp16",
+        ...     torch_dtype=torch.float16
         ... )
+        >>> pipe.enable_vae_tiling()
         >>> pipe = pipe.to("cuda")

         >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png"
-        >>> image = load_image(img_url).convert("RGB")
+        >>> image = load_image(img_url).resize((512, 512))

         >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)
@@ -152,7 +152,7 @@ def __init__(self, device):

         # The gaussian kernel is the product of the gaussian function of each dimension.
         kernel = 1
-        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
         for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
             mean = (size - 1) / 2
             kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
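
The `indexing="ij"` argument (added in both files) keeps the long-standing default behaviour of `torch.meshgrid` while silencing the deprecation warning newer PyTorch versions raise when no `indexing` mode is given. A standalone sanity-check sketch, separate from the pipeline code:

    import torch

    kernel_size = [3, 3]
    # Old call: warns on recent PyTorch, implicitly uses "ij" indexing.
    grids_old = torch.meshgrid([torch.arange(s, dtype=torch.float32) for s in kernel_size])
    # New call: explicit indexing, no warning, identical result.
    grids_new = torch.meshgrid([torch.arange(s, dtype=torch.float32) for s in kernel_size], indexing="ij")
    assert all(torch.equal(a, b) for a, b in zip(grids_old, grids_new))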
@@ -706,6 +706,35 @@ def clip_skip(self):
     def cross_attention_kwargs(self):
         return self._cross_attention_kwargs

+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
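
The four new methods simply forward to the pipeline's VAE, so they can be toggled around memory-heavy calls. A brief usage sketch, assuming `pipe` is the `LEditsPPPipelineStableDiffusion` instance from the example docstring above:

    # Reduce peak VAE memory at the cost of some speed.
    pipe.enable_vae_tiling()    # encode/decode large images tile by tile
    pipe.enable_vae_slicing()   # decode batched latents one slice at a time

    # ... run pipe.invert(...) and pipe(...) here ...

    # Return to single-pass VAE behaviour.
    pipe.disable_vae_tiling()
    pipe.disable_vae_slicing()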
@@ -1271,6 +1300,8 @@ def invert(
             [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
             and respective VAE reconstruction(s).
         """
+        if height is not None and height % 32 != 0 or width is not None and width % 32 != 0:
+            raise ValueError("height and width must be a factor of 32.")
         # Reset attn processor, we do not want to store attn maps during inversion
         self.unet.set_attn_processor(AttnProcessor())

@@ -1360,6 +1391,12 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="
         image = self.image_processor.preprocess(
             image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
         )
+        height, width = image.shape[-2:]
+        if height % 32 != 0 or width % 32 != 0:
+            raise ValueError(
+                "Image height and width must be a factor of 32. "
+                "Consider down-sampling the input using the `height` and `width` parameters"
+            )
         resized = self.image_processor.postprocess(image=image, output_type="pil")

         if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
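
Because `encode_image` now rejects preprocessed images whose height or width is not divisible by 32, inputs with odd sizes have to be brought to a valid resolution before inversion. A hedged sketch of two ways to satisfy the check, reusing `pipe` and `img_url` from the example docstring above and assuming the `height`/`width` keyword arguments of `invert` referenced in the new check (the 512x768 target size is only illustrative):

    from diffusers.utils import load_image

    # Option 1: resize the PIL image yourself before inversion.
    image = load_image(img_url).resize((512, 768))
    _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)

    # Option 2: pass height/width and let the pipeline's image processor
    # resample the input, as the new error message suggests.
    _ = pipe.invert(
        image=load_image(img_url), height=768, width=512,
        num_inversion_steps=50, skip=0.1,
    )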

src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py

Lines changed: 56 additions & 13 deletions
@@ -72,25 +72,20 @@
     Examples:
         ```py
         >>> import torch
-        >>> import PIL
-        >>> import requests
-        >>> from io import BytesIO

         >>> from diffusers import LEditsPPPipelineStableDiffusionXL
+        >>> from diffusers.utils import load_image

         >>> pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
-        ...     "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ...     "stabilityai/stable-diffusion-xl-base-1.0",
+        ...     variant="fp16",
+        ...     torch_dtype=torch.float16
         ... )
+        >>> pipe.enable_vae_tiling()
         >>> pipe = pipe.to("cuda")

-
-        >>> def download_image(url):
-        ...     response = requests.get(url)
-        ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")
-
-
         >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
-        >>> image = download_image(img_url)
+        >>> image = load_image(img_url).resize((1024, 1024))

         >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2)
@@ -197,7 +192,7 @@ def __init__(self, device):

         # The gaussian kernel is the product of the gaussian function of each dimension.
         kernel = 1
-        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
         for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
             mean = (size - 1) / 2
             kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
@@ -768,6 +763,35 @@ def denoising_end(self):
     def num_timesteps(self):
         return self._num_timesteps

+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.prepare_unet
     def prepare_unet(self, attention_store, PnP: bool = False):
         attn_procs = {}
@@ -1401,6 +1425,12 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="
         image = self.image_processor.preprocess(
             image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
         )
+        height, width = image.shape[-2:]
+        if height % 32 != 0 or width % 32 != 0:
+            raise ValueError(
+                "Image height and width must be a factor of 32. "
+                "Consider down-sampling the input using the `height` and `width` parameters"
+            )
         resized = self.image_processor.postprocess(image=image, output_type="pil")

         if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
@@ -1439,6 +1469,10 @@ def invert(
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         num_zero_noise_steps: int = 3,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        resize_mode: Optional[str] = "default",
+        crops_coords: Optional[Tuple[int, int, int, int]] = None,
     ):
         r"""
         The function to the pipeline for image inversion as described by the [LEDITS++
@@ -1486,6 +1520,8 @@ def invert(
             [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
             and respective VAE reconstruction(s).
         """
+        if height is not None and height % 32 != 0 or width is not None and width % 32 != 0:
+            raise ValueError("height and width must be a factor of 32.")

         # Reset attn processor, we do not want to store attn maps during inversion
         self.unet.set_attn_processor(AttnProcessor())
@@ -1510,7 +1546,14 @@ def invert(
         do_classifier_free_guidance = source_guidance_scale > 1.0

         # 1. prepare image
-        x0, resized = self.encode_image(image, dtype=self.text_encoder_2.dtype)
+        x0, resized = self.encode_image(
+            image,
+            dtype=self.text_encoder_2.dtype,
+            height=height,
+            width=width,
+            resize_mode=resize_mode,
+            crops_coords=crops_coords,
+        )
         width = x0.shape[2] * self.vae_scale_factor
         height = x0.shape[3] * self.vae_scale_factor
         self.size = (height, width)
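
With the new keyword arguments threaded from `invert` into `encode_image`, resizing can now be controlled at inversion time on the SDXL pipeline as well. A minimal sketch, assuming `pipe` and `image` are the `LEditsPPPipelineStableDiffusionXL` instance and tennis image from the SDXL example docstring; the parameter values are illustrative:

    _ = pipe.invert(
        image=image,
        num_inversion_steps=50,
        skip=0.2,
        height=1024,            # must be divisible by 32
        width=1024,             # must be divisible by 32
        resize_mode="default",  # forwarded to the image processor
        crops_coords=None,      # optional crop box, also forwarded
    )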
