@@ -663,12 +663,11 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
663663 block_state .device = components ._execution_device
664664 block_state .dtype = block_state .dtype if block_state .dtype is not None else components .vae .dtype
665665
666- block_state . image = components .image_processor .preprocess (
666+ image = components .image_processor .preprocess (
667667 block_state .image , height = block_state .height , width = block_state .width , ** block_state .preprocess_kwargs
668668 )
669- block_state .image = block_state .image .to (device = block_state .device , dtype = block_state .dtype )
670-
671- block_state .batch_size = block_state .image .shape [0 ]
669+ image = image .to (device = block_state .device , dtype = block_state .dtype )
670+ block_state .batch_size = image .shape [0 ]
672671
673672 # if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
674673 if isinstance (block_state .generator , list ) and len (block_state .generator ) != block_state .batch_size :
@@ -677,9 +676,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
677676 f" size of { block_state .batch_size } . Make sure the batch size matches the length of the generators."
678677 )
679678
680- block_state .image_latents = self ._encode_vae_image (
681- components , image = block_state .image , generator = block_state .generator
682- )
679+ block_state .image_latents = self ._encode_vae_image (components , image = image , generator = block_state .generator )
683680
684681 self .set_block_state (state , block_state )
685682
@@ -850,34 +847,32 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
850847 block_state .crops_coords = None
851848 block_state .resize_mode = "default"
852849
853- block_state . image = components .image_processor .preprocess (
850+ image = components .image_processor .preprocess (
854851 block_state .image ,
855852 height = block_state .height ,
856853 width = block_state .width ,
857854 crops_coords = block_state .crops_coords ,
858855 resize_mode = block_state .resize_mode ,
859856 )
860- block_state . image = block_state . image .to (dtype = torch .float32 )
857+ image = image .to (dtype = torch .float32 )
861858
862- block_state . mask = components .mask_processor .preprocess (
859+ mask = components .mask_processor .preprocess (
863860 block_state .mask_image ,
864861 height = block_state .height ,
865862 width = block_state .width ,
866863 resize_mode = block_state .resize_mode ,
867864 crops_coords = block_state .crops_coords ,
868865 )
869- block_state .masked_image = block_state . image * (block_state . mask < 0.5 )
866+ block_state .masked_image = image * (mask < 0.5 )
870867
871- block_state .batch_size = block_state .image .shape [0 ]
872- block_state .image = block_state .image .to (device = block_state .device , dtype = block_state .dtype )
873- block_state .image_latents = self ._encode_vae_image (
874- components , image = block_state .image , generator = block_state .generator
875- )
868+ block_state .batch_size = image .shape [0 ]
869+ image = image .to (device = block_state .device , dtype = block_state .dtype )
870+ block_state .image_latents = self ._encode_vae_image (components , image = image , generator = block_state .generator )
876871
877872 # 7. Prepare mask latent variables
878873 block_state .mask , block_state .masked_image_latents = self .prepare_mask_latents (
879874 components ,
880- block_state . mask ,
875+ mask ,
881876 block_state .masked_image ,
882877 block_state .batch_size ,
883878 block_state .height ,
0 commit comments