@@ -40,7 +40,7 @@
     AttnProcessor2_0,
     XFormersAttnProcessor,
 )
-from ...models.controlnet_union import ControlNetUnionInput, ControlNetUnionInputProMax
+from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
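Note for downstream code: the classes move from the dedicated `controlnet_union` module to the consolidated `controlnets` module. A hedged sketch of the matching import update, assuming the relative path `...models.controlnets` corresponds to the public module `diffusers.models.controlnets` (the re-export itself is not shown in this diff):

```python
# Old internal location (before this change):
# from diffusers.models.controlnet_union import ControlNetUnionInput, ControlNetUnionInputProMax

# New location after the controlnet modules were consolidated (assumed public path):
from diffusers.models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
```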
@@ -605,6 +605,7 @@ def prepare_extra_step_kwargs(self, generator, eta):
         extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
+    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image
     def check_image(self, image, prompt, prompt_embeds):
         image_is_pil = isinstance(image, PIL.Image.Image)
         image_is_tensor = isinstance(image, torch.Tensor)
@@ -826,6 +827,7 @@ def check_inputs(
                     f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                 )
 
+    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
     def prepare_control_image(
         self,
         image,
@@ -860,6 +862,7 @@ def prepare_control_image(
 
         return image
 
+    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_inpaint_sd_xl.StableDiffusionXLControlNetInpaintPipeline.prepare_latents
     def prepare_latents(
         self,
         batch_size,
@@ -927,6 +930,7 @@ def prepare_latents(
 
         return outputs
 
+    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_inpaint_sd_xl.StableDiffusionXLControlNetInpaintPipeline._encode_vae_image
     def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         dtype = image.dtype
         if self.vae.config.force_upcast:
@@ -950,6 +954,7 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
 
         return image_latents
 
+    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_inpaint_sd_xl.StableDiffusionXLControlNetInpaintPipeline.prepare_mask_latents
     def prepare_mask_latents(
         self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
     ):
@@ -1560,7 +1565,7 @@ def denoising_value_valid(dnv):
         latents, noise = latents_outputs
 
         # 7. Prepare mask latent variables
-        mask, masked_image_latents = self.prepare_mask_latents(
+        mask, _ = self.prepare_mask_latents(
             mask,
             masked_image,
             batch_size * num_images_per_prompt,
@@ -1573,19 +1578,7 @@ def denoising_value_valid(dnv):
         )
 
         # 8. Check that sizes of mask, masked image and latents match
-        if num_channels_unet == 9:
-            # default case for runwayml/stable-diffusion-inpainting
-            num_channels_mask = mask.shape[1]
-            num_channels_masked_image = masked_image_latents.shape[1]
-            if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
-                raise ValueError(
-                    f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
-                    f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
-                    f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
-                    f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
-                    " `pipeline.unet` or your `mask_image` or `image` input."
-                )
-        elif num_channels_unet != 4:
+        if num_channels_unet != 4:
             raise ValueError(
                 f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
             )
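For context, the removed branch handled 9-channel inpainting UNets, which receive the mask and masked-image latents concatenated onto the latent channels; the pipeline now only accepts plain 4-channel SDXL UNets and applies the mask by blending latents after each scheduler step instead. A minimal, hedged sketch of the channel bookkeeping (tensor shapes are illustrative only):

```python
import torch

# Illustrative shapes: batch 1, 4 latent channels, a 128x128 latent grid.
latents = torch.randn(1, 4, 128, 128)
mask = torch.rand(1, 1, 128, 128)
masked_image_latents = torch.randn(1, 4, 128, 128)

# 9-channel inpainting UNet (the removed path): 4 + 1 + 4 channels concatenated.
nine_channel_input = torch.cat([latents, mask, masked_image_latents], dim=1)
assert nine_channel_input.shape[1] == 9

# 4-channel UNet (the only path kept): the model sees the latents alone, and the
# mask is applied outside the UNet via latent blending after each step.
four_channel_input = latents
assert four_channel_input.shape[1] == 4
```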
@@ -1673,7 +1666,6 @@ def denoising_value_valid(dnv):
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
 
-                # concat latents, mask, masked_image_latents in the channel dimension
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 added_cond_kwargs = {
@@ -1730,9 +1722,6 @@ def denoising_value_valid(dnv):
                 if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                     added_cond_kwargs["image_embeds"] = image_embeds
 
-                if num_channels_unet == 9:
-                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
-
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
@@ -1757,20 +1746,19 @@ def denoising_value_valid(dnv):
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
-                if num_channels_unet == 4:
-                    init_latents_proper = image_latents
-                    if self.do_classifier_free_guidance:
-                        init_mask, _ = mask.chunk(2)
-                    else:
-                        init_mask = mask
+                init_latents_proper = image_latents
+                if self.do_classifier_free_guidance:
+                    init_mask, _ = mask.chunk(2)
+                else:
+                    init_mask = mask
 
-                    if i < len(timesteps) - 1:
-                        noise_timestep = timesteps[i + 1]
-                        init_latents_proper = self.scheduler.add_noise(
-                            init_latents_proper, noise, torch.tensor([noise_timestep])
-                        )
+                if i < len(timesteps) - 1:
+                    noise_timestep = timesteps[i + 1]
+                    init_latents_proper = self.scheduler.add_noise(
+                        init_latents_proper, noise, torch.tensor([noise_timestep])
+                    )
 
-                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+                latents = (1 - init_mask) * init_latents_proper + init_mask * latents
 
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
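The last hunk makes the mask-guided blending unconditional: after every scheduler step, the region outside the mask is reset to a re-noised copy of the original image latents, so only the masked area is actually repainted. A standalone, hedged sketch of that step; the function name and its arguments are illustrative stand-ins for the pipeline's internal tensors and `scheduler.add_noise`:

```python
import torch

def blend_inpainting_step(latents, image_latents, noise, mask, timesteps, i, add_noise):
    """Keep the denoiser's output inside the mask, re-impose the original image outside it.

    `add_noise(sample, noise, t)` stands in for `scheduler.add_noise`; all tensors are
    illustrative stand-ins for the pipeline's internals.
    """
    init_latents_proper = image_latents
    if i < len(timesteps) - 1:
        # Re-noise the original latents to the noise level of the next step so the
        # blend stays consistent with the scheduler's trajectory.
        noise_timestep = timesteps[i + 1]
        init_latents_proper = add_noise(init_latents_proper, noise, torch.tensor([noise_timestep]))
    # mask == 1 inside the region to repaint (keep the denoiser output),
    # mask == 0 outside it (keep the original image content).
    return (1 - mask) * init_latents_proper + mask * latents
```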