3939from ..modular_pipeline_utils import ComponentSpec , ConfigSpec , InputParam , OutputParam
4040from .modular_pipeline import StableDiffusionXLModularPipeline
4141
42+ from PIL import Image
43+
4244
4345logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
4446
@@ -504,11 +506,11 @@ def encode_prompt(
504506 negative_prompt_embeds = torch .concat (negative_prompt_embeds_list , dim = - 1 )
505507 negative_pooled_prompt_embeds = torch .concat (negative_pooled_prompt_embeds_list , dim = 0 )
506508
507- prompt_embeds = prompt_embeds .to (dtype , device = device )
508- pooled_prompt_embeds = pooled_prompt_embeds .to (dtype , device = device )
509+ prompt_embeds = prompt_embeds .to (dtype = dtype , device = device )
510+ pooled_prompt_embeds = pooled_prompt_embeds .to (dtype = dtype , device = device )
509511 if requires_unconditional_embeds :
510- negative_prompt_embeds = negative_prompt_embeds .to (dtype , device = device )
511- negative_pooled_prompt_embeds = negative_pooled_prompt_embeds .to (dtype , device = device )
512+ negative_prompt_embeds = negative_prompt_embeds .to (dtype = dtype , device = device )
513+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds .to (dtype = dtype , device = device )
512514
513515 for text_encoder in text_encoders :
514516 if isinstance (components , StableDiffusionXLLoraLoaderMixin ) and USE_PEFT_BACKEND :
@@ -687,12 +689,12 @@ def intermediate_outputs(self) -> List[OutputParam]:
687689
688690 def check_inputs (self , image , mask_image , padding_mask_crop ):
689691
690- if padding_mask_crop is not None and not isinstance (image , PIL . Image .Image ):
692+ if padding_mask_crop is not None and not isinstance (image , Image .Image ):
691693 raise ValueError (
692694 f"The image should be a PIL image when inpainting mask crop, but is of type { type (image )} ."
693695 )
694696
695- if padding_mask_crop is not None and not isinstance (mask_image , PIL . Image .Image ):
697+ if padding_mask_crop is not None and not isinstance (mask_image , Image .Image ):
696698 raise ValueError (
697699 f"The mask image should be a PIL image when inpainting mask crop, but is of type"
698700 f" { type (mask_image )} ."
@@ -707,10 +709,8 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
707709 dtype = block_state .dtype if block_state .dtype is not None else components .vae .dtype
708710 device = components ._execution_device
709711
710- if block_state .height is None :
711- height = components .default_height
712- if block_state .width is None :
713- width = components .default_width
712+ height = block_state .height if block_state .height is not None else components .default_height
713+ width = block_state .width if block_state .width is not None else components .default_width
714714
715715 if block_state .padding_mask_crop is not None :
716716 block_state .crops_coords = components .mask_processor .get_crop_region (
@@ -725,21 +725,21 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
725725 block_state .image ,
726726 height = height ,
727727 width = width ,
728- crops_coords = crops_coords ,
728+ crops_coords = block_state . crops_coords ,
729729 resize_mode = resize_mode ,
730730 )
731731
732732 image = image .to (dtype = torch .float32 )
733733
734- mask = components .mask_processor .preprocess (
734+ mask_image = components .mask_processor .preprocess (
735735 block_state .mask_image ,
736736 height = height ,
737737 width = width ,
738738 resize_mode = resize_mode ,
739- crops_coords = crops_coords ,
739+ crops_coords = block_state . crops_coords ,
740740 )
741741
742- masked_image = image * (block_state . mask_latents < 0.5 )
742+ masked_image = image * (mask_image < 0.5 )
743743
744744 # Prepare image latent variables
745745 block_state .image_latents = encode_vae_image (
@@ -762,7 +762,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
762762 # resize mask to match the image latents
763763 _ , _ , height_latents , width_latents = block_state .image_latents .shape
764764 block_state .mask = torch .nn .functional .interpolate (
765- mask ,
765+ mask_image ,
766766 size = (height_latents , width_latents ),
767767 )
768768 block_state .mask = block_state .mask .to (dtype = dtype , device = device )
0 commit comments