 >>> pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
 ...     "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
 ... )
+>>> pipe.enable_vae_tiling()
 >>> pipe = pipe.to("cuda")


 >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
->>> image = download_image(img_url)
+>>> image = download_image(img_url).resize((1024, 1024))

 >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2)

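The two docstring changes above work together: tiling keeps VAE memory bounded, and the 1024x1024 resize satisfies the divisibility check introduced further down. A minimal end-to-end sketch, assuming the model id from the docstring; the `download_image` helper and the editing parameters here are illustrative, not part of this diff:

```python
import torch
import PIL.Image
import requests

from diffusers import LEditsPPPipelineStableDiffusionXL

pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.enable_vae_tiling()  # decode/encode latents tile by tile to bound peak VRAM
pipe = pipe.to("cuda")


def download_image(url):
    # Illustrative helper; the full docstring defines its own version.
    return PIL.Image.open(requests.get(url, stream=True).raw).convert("RGB")


image = download_image(
    "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
).resize((1024, 1024))  # both sides must be multiples of 32

_ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2)
edited = pipe(
    editing_prompt=["tennis ball"],
    reverse_editing_direction=[True],  # illustrative edit: remove the tennis ball
).images[0]
```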
@@ -197,7 +198,7 @@ def __init__(self, device):

         # The gaussian kernel is the product of the gaussian function of each dimension.
         kernel = 1
-        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
         for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
             mean = (size - 1) / 2
             kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
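The only behavioral note here: `torch.meshgrid` has warned since PyTorch 1.10 when `indexing` is omitted, and `indexing="ij"` reproduces the historical default, so the kernel values are unchanged. A quick sketch of the invariant:

```python
import torch

# "ij" (matrix) indexing: output dim 0 follows the first input, matching
# torch.meshgrid's legacy default, so no warning and no change in values.
grids = torch.meshgrid([torch.arange(3.0), torch.arange(5.0)], indexing="ij")
assert grids[0].shape == grids[1].shape == (3, 5)
```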
@@ -768,6 +769,35 @@ def denoising_end(self):
     def num_timesteps(self):
         return self._num_timesteps

+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful to save a large amount of memory and to
+        allow processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.prepare_unet
     def prepare_unet(self, attention_store, PnP: bool = False):
         attn_procs = {}
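The four new methods are thin pass-throughs to the underlying `AutoencoderKL`. A sketch of when each toggle helps, assuming `pipe` is the pipeline from the docstring example:

```python
# Slicing trades speed for memory across the batch dimension;
# tiling does the same across the spatial dimensions.
pipe.enable_vae_slicing()  # decode a batch of latents one image at a time
pipe.enable_vae_tiling()   # decode/encode each image in overlapping tiles

# ... run memory-heavy inversion or editing on large inputs ...

pipe.disable_vae_slicing()  # restore single-pass decoding
pipe.disable_vae_tiling()
```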
@@ -1401,6 +1431,12 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="
         image = self.image_processor.preprocess(
             image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
         )
+        height, width = image.shape[-2:]
+        if height % 32 != 0 or width % 32 != 0:
+            raise ValueError(
+                "Image height and width must be multiples of 32. "
+                "Consider resizing the input via the `height` and `width` parameters."
+            )
         resized = self.image_processor.postprocess(image=image, output_type="pil")

         if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
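Since `encode_image` now rejects sizes that are not multiples of 32, a caller with an arbitrary input can either pass explicit `height`/`width` or snap the image itself. A hedged helper sketch; the function name is ours, not part of the diff:

```python
def snap_to_multiple_of_32(image):
    """Resize a PIL image down so both sides are multiples of 32."""
    w, h = image.size
    return image.resize((w - w % 32, h - h % 32))

# e.g. a 1023x765 photo becomes 992x736 and passes the new check
image = snap_to_multiple_of_32(image)
```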
@@ -1439,6 +1475,10 @@ def invert(
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         num_zero_noise_steps: int = 3,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        resize_mode: Optional[str] = "default",
+        crops_coords: Optional[Tuple[int, int, int, int]] = None,
     ):
         r"""
         The function to the pipeline for image inversion as described by the [LEDITS++
@@ -1486,6 +1526,8 @@ def invert(
             [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
             and respective VAE reconstruction(s).
         """
+        if (height is not None and height % 32 != 0) or (width is not None and width % 32 != 0):
+            raise ValueError("`height` and `width` must be multiples of 32.")

         # Reset attn processor, we do not want to store attn maps during inversion
         self.unet.set_attn_processor(AttnProcessor())
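Note the guard above must tolerate the `None` defaults, since both parameters are optional and `None % 32` would raise a `TypeError`. With the new parameters threaded through, `invert` can now resize and crop internally; an illustrative call (values are ours):

```python
_ = pipe.invert(
    image=image,
    num_inversion_steps=50,
    skip=0.2,
    height=1024,            # must be a multiple of 32
    width=1024,             # must be a multiple of 32
    resize_mode="default",  # forwarded to VaeImageProcessor.preprocess
)
```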
@@ -1510,7 +1552,14 @@ def invert(
         do_classifier_free_guidance = source_guidance_scale > 1.0

         # 1. prepare image
-        x0, resized = self.encode_image(image, dtype=self.text_encoder_2.dtype)
+        x0, resized = self.encode_image(
+            image,
+            dtype=self.text_encoder_2.dtype,
+            height=height,
+            width=width,
+            resize_mode=resize_mode,
+            crops_coords=crops_coords,
+        )
         width = x0.shape[2] * self.vae_scale_factor
         height = x0.shape[3] * self.vae_scale_factor
         self.size = (height, width)
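For reference, the stored `self.size` is recovered from the latent shape via `vae_scale_factor` (8 for the SDXL VAE), so a 1024x1024 input round-trips as expected. A toy check with an illustrative latent:

```python
import torch

x0 = torch.randn(1, 4, 128, 128)  # illustrative SDXL latent for a 1024x1024 image
vae_scale_factor = 8
print(x0.shape[2] * vae_scale_factor, x0.shape[3] * vae_scale_factor)  # 1024 1024
```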