
Commit 9ee94f9

LEditsPP - examples, check height/width, add tiling/slicing
1 parent b572635 commit 9ee94f9

File tree

2 files changed: +92 -5 lines changed

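At a glance, the commit threads three changes through both LEDITS++ pipelines: the example docstrings now enable VAE tiling and resize inputs, `torch.meshgrid` gets an explicit indexing mode, and `invert`/`encode_image` validate that image sides are multiples of 32. A hedged end-to-end sketch of the resulting workflow, pieced together from the updated docstrings; the final editing call and its prompt are illustrative, not part of this commit:

    import torch
    from diffusers import LEditsPPPipelineStableDiffusion
    from diffusers.utils import load_image

    pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    pipe.enable_vae_tiling()  # new in this commit: tiled VAE decoding to save memory
    pipe = pipe.to("cuda")

    img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png"
    # Sides must be a multiple of 32 after this commit, hence the explicit resize.
    image = load_image(img_url).convert("RGB").resize((512, 512))

    _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)
    # Illustrative edit step (not shown in the diff below):
    edited = pipe(editing_prompt=["cherry blossom"], edit_guidance_scale=10.0).images[0]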

src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py

Lines changed: 40 additions & 2 deletions
@@ -45,10 +45,11 @@
         >>> pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
         ...     "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
         ... )
+        >>> pipe.enable_vae_tiling()
         >>> pipe = pipe.to("cuda")

         >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png"
-        >>> image = load_image(img_url).convert("RGB")
+        >>> image = load_image(img_url).convert("RGB").resize((512, 512))

         >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)
@@ -152,7 +153,7 @@ def __init__(self, device):

         # The gaussian kernel is the product of the gaussian function of each dimension.
         kernel = 1
-        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
         for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
             mean = (size - 1) / 2
             kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
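For context on the one-line change above: PyTorch 1.10+ emits a warning when `torch.meshgrid` is called without an explicit `indexing` argument, and `indexing="ij"` pins the long-standing default (matrix indexing), so the Gaussian kernel built from these grids is unchanged. A minimal sketch:

    import torch

    kernel_size = (3, 3)
    # "ij" reproduces the historical default behavior; the argument only
    # silences the deprecation warning, the values are identical.
    meshgrids = torch.meshgrid(
        [torch.arange(size, dtype=torch.float32) for size in kernel_size],
        indexing="ij",
    )
    print(meshgrids[0])  # varies along dim 0; meshgrids[1] varies along dim 1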
@@ -706,6 +707,35 @@ def clip_skip(self):
     def cross_attention_kwargs(self):
         return self._cross_attention_kwargs

+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
+        allow processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
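The four new methods simply delegate to the underlying `AutoencoderKL` toggles (`self.vae.enable_slicing()` and friends), so they can be flipped around memory-heavy calls. A hedged sketch, reusing `pipe` and `image` from the sketch above:

    pipe.enable_vae_slicing()   # decode batched latents one slice at a time
    pipe.enable_vae_tiling()    # decode/encode large latents tile by tile
    _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)
    pipe.disable_vae_slicing()  # restore single-step decoding
    pipe.disable_vae_tiling()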
@@ -1271,6 +1301,8 @@ def invert(
             [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
             and respective VAE reconstruction(s).
         """
+        if height % 32 != 0 or width % 32 != 0:
+            raise ValueError("height and width must be a factor of 32.")
         # Reset attn processor, we do not want to store attn maps during inversion
         self.unet.set_attn_processor(AttnProcessor())

@@ -1360,6 +1392,12 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="
         image = self.image_processor.preprocess(
             image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
         )
+        height, width = image.shape[-2:]
+        if height % 32 != 0 or width % 32 != 0:
+            raise ValueError(
+                "Image height and width must be a factor of 32. "
+                "Consider down-sampling the input using the `height` and `width` parameters"
+            )
         resized = self.image_processor.postprocess(image=image, output_type="pil")

         if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
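Note that the check runs after `preprocess`, so any `height`/`width` passed in are applied first and the validation sees the final tensor shape. The committed message says "factor of 32", though the condition actually enforces a multiple of 32; a hypothetical standalone version of the same check, using the latter wording:

    def check_vae_dims(height: int, width: int) -> None:
        # Hypothetical helper mirroring the new validation: both sides must be
        # divisible by 32 so the VAE's downsampling stages line up cleanly.
        if height % 32 != 0 or width % 32 != 0:
            raise ValueError(
                f"Image height and width must be a multiple of 32, got {height}x{width}. "
                "Consider down-sampling the input using the `height` and `width` parameters."
            )

    check_vae_dims(512, 512)    # passes silently
    # check_vae_dims(500, 512)  # would raise ValueError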

src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py

Lines changed: 52 additions & 3 deletions
@@ -81,6 +81,7 @@
         >>> pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
         ...     "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
         ... )
+        >>> pipe.enable_vae_tiling()
         >>> pipe = pipe.to("cuda")

@@ -90,7 +91,7 @@
         >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
-        >>> image = download_image(img_url)
+        >>> image = download_image(img_url).resize((1024, 1024))

         >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2)
@@ -197,7 +198,7 @@ def __init__(self, device):

         # The gaussian kernel is the product of the gaussian function of each dimension.
         kernel = 1
-        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
         for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
             mean = (size - 1) / 2
             kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
@@ -768,6 +769,35 @@ def denoising_end(self):
     def num_timesteps(self):
         return self._num_timesteps

+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
+        allow processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.prepare_unet
     def prepare_unet(self, attention_store, PnP: bool = False):
         attn_procs = {}
@@ -1401,6 +1431,12 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="
         image = self.image_processor.preprocess(
             image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
         )
+        height, width = image.shape[-2:]
+        if height % 32 != 0 or width % 32 != 0:
+            raise ValueError(
+                "Image height and width must be a factor of 32. "
+                "Consider down-sampling the input using the `height` and `width` parameters"
+            )
         resized = self.image_processor.postprocess(image=image, output_type="pil")

         if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
@@ -1439,6 +1475,10 @@ def invert(
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         num_zero_noise_steps: int = 3,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        resize_mode: Optional[str] = "default",
+        crops_coords: Optional[Tuple[int, int, int, int]] = None,
     ):
         r"""
         The function to the pipeline for image inversion as described by the [LEDITS++
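With the new parameters, the SDXL pipeline can down-sample an oversized input itself instead of requiring a manual `.resize()` beforehand. A hedged sketch, assuming `pipe` is the `LEditsPPPipelineStableDiffusionXL` instance from the example above and `image` is any PIL image; the values are illustrative:

    _ = pipe.invert(
        image=image,
        num_inversion_steps=50,
        skip=0.2,
        height=1024,             # must be a multiple of 32
        width=1024,
        resize_mode="default",
    )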
@@ -1486,6 +1526,8 @@ def invert(
             [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
             and respective VAE reconstruction(s).
         """
+        if height % 32 != 0 or width % 32 != 0:
+            raise ValueError("height and width must be a factor of 32.")

         # Reset attn processor, we do not want to store attn maps during inversion
         self.unet.set_attn_processor(AttnProcessor())
@@ -1510,7 +1552,14 @@ def invert(
         do_classifier_free_guidance = source_guidance_scale > 1.0

         # 1. prepare image
-        x0, resized = self.encode_image(image, dtype=self.text_encoder_2.dtype)
+        x0, resized = self.encode_image(
+            image,
+            dtype=self.text_encoder_2.dtype,
+            height=height,
+            width=width,
+            resize_mode=resize_mode,
+            crops_coords=crops_coords,
+        )
         width = x0.shape[2] * self.vae_scale_factor
         height = x0.shape[3] * self.vae_scale_factor
         self.size = (height, width)
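For reference, the lines after the updated `encode_image` call recover pixel dimensions from the latent shape; with the usual SD/SDXL `vae_scale_factor` of 8, the arithmetic looks like this (shapes here are illustrative):

    vae_scale_factor = 8                      # spatial downsampling of the VAE
    x0_shape = (1, 4, 128, 128)               # latents for a 1024x1024 input
    width = x0_shape[2] * vae_scale_factor    # 1024, mirroring the code above
    height = x0_shape[3] * vae_scale_factor   # 1024
    size = (height, width)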
