diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
index e4b4295ef7d1..2b0fb9ae6670 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
@@ -1582,6 +1582,293 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
         return pipeline, state
 
 
+class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
+    expected_components = ["unet", "controlnet", "scheduler", "guider", "controlnet_guider"]
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            ("control_image", None),
+            ("control_guidance_start", 0.0),
+            ("control_guidance_end", 1.0),
+            ("controlnet_conditioning_scale", 1.0),
+            ("control_mode", 0),
+            ("guess_mode", False),
+            ("num_images_per_prompt", 1),
+            ("guidance_scale", 5.0),
+            ("guidance_rescale", 0.0),
+            ("cross_attention_kwargs", None),
+            ("generator", None),
+            ("eta", 0.0),
+            ("guider_kwargs", None),
+        ]
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return [
+            "latents",
+            "batch_size",
+            "timesteps",
+            "num_inference_steps",
+            "prompt_embeds",
+            "negative_prompt_embeds",
+            "add_time_ids",
+            "negative_add_time_ids",
+            "pooled_prompt_embeds",
+            "negative_pooled_prompt_embeds",
+            "timestep_cond",
+            "mask",
+            "noise",
+            "image_latents",
+            "crops_coords",
+        ]
+
+    @property
+    def intermediates_outputs(self) -> List[str]:
+        return ["latents"]
+
+    def __init__(self):
+        super().__init__()
+        self.components["guider"] = CFGGuider()
+        self.components["controlnet_guider"] = CFGGuider()
+        self.components["scheduler"] = None
+        self.components["unet"] = None
+        self.components["controlnet"] = None
+        control_image_processor = VaeImageProcessor(do_convert_rgb=True, do_normalize=False)
+        self.auxiliaries["control_image_processor"] = control_image_processor
+
+    @torch.no_grad()
+    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
+        guidance_scale = state.get_input("guidance_scale")
+        guidance_rescale = state.get_input("guidance_rescale")
+        cross_attention_kwargs = state.get_input("cross_attention_kwargs")
+        guider_kwargs = state.get_input("guider_kwargs")
+        generator = state.get_input("generator")
+        eta = state.get_input("eta")
+        num_images_per_prompt = state.get_input("num_images_per_prompt")
+        # controlnet-specific inputs
+        control_image = state.get_input("control_image")
+        control_guidance_start = state.get_input("control_guidance_start")
+        control_guidance_end = state.get_input("control_guidance_end")
+        controlnet_conditioning_scale = state.get_input("controlnet_conditioning_scale")
+        control_mode = state.get_input("control_mode")
+        guess_mode = state.get_input("guess_mode")
+
+        batch_size = state.get_intermediate("batch_size")
+        latents = state.get_intermediate("latents")
+        timesteps = state.get_intermediate("timesteps")
+        num_inference_steps = state.get_intermediate("num_inference_steps")
+
+        prompt_embeds = state.get_intermediate("prompt_embeds")
+        negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds")
+        pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds")
+        negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds")
+        add_time_ids = state.get_intermediate("add_time_ids")
+        negative_add_time_ids = state.get_intermediate("negative_add_time_ids")
+
+        timestep_cond = state.get_intermediate("timestep_cond")
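+        # timestep_cond is typically only populated for guidance-distilled unets
+        # (e.g. LCM), i.e. when unet.config.time_cond_proj_dim is not None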
state.get_intermediate("timestep_cond") + + # inpainting + mask = state.get_intermediate("mask") + noise = state.get_intermediate("noise") + image_latents = state.get_intermediate("image_latents") + crops_coords = state.get_intermediate("crops_coords") + + device = pipeline._execution_device + + height, width = latents.shape[-2:] + height = height * pipeline.vae_scale_factor + width = width * pipeline.vae_scale_factor + + # prepare controlnet inputs + controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + + global_pool_conditions = controlnet.config.global_pool_conditions + guess_mode = guess_mode or global_pool_conditions + + num_control_type = controlnet.config.num_control_type + + if not isinstance(control_image, list): + control_image = [control_image] + + if not isinstance(control_mode, list): + control_mode = [control_mode] + + if len(control_image) != len(control_mode): + raise ValueError("Expected len(control_image) == len(control_type)") + + control_type = [0 for _ in range(num_control_type)] + for control_idx in control_mode: + control_type[control_idx] = 1 + + control_type = torch.Tensor(control_type) + + for idx, _ in enumerate(control_image): + control_image[idx] = pipeline.prepare_control_image( + image=control_image[idx], + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + ) + height, width = control_image[idx].shape[-2:] + + controlnet_keep = [] + for i in range(len(timesteps)): + controlnet_keep.append( + 1.0 + - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end) + ) + + # Prepare conditional inputs for unet using the guider + # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale + disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False + guider_kwargs = guider_kwargs or {} + guider_kwargs = { + **guider_kwargs, + "disable_guidance": disable_guidance, + "guidance_scale": guidance_scale, + "guidance_rescale": guidance_rescale, + "batch_size": batch_size, + } + pipeline.guider.set_guider(pipeline, guider_kwargs) + prompt_embeds = pipeline.guider.prepare_input( + prompt_embeds, + negative_prompt_embeds, + ) + add_time_ids = pipeline.guider.prepare_input( + add_time_ids, + negative_add_time_ids, + ) + pooled_prompt_embeds = pipeline.guider.prepare_input( + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + added_cond_kwargs = { + "text_embeds": pooled_prompt_embeds, + "time_ids": add_time_ids, + } + + # Prepare conditional inputs for controlnet using the guider + controlnet_disable_guidance = True if disable_guidance or guess_mode else False + controlnet_guider_kwargs = guider_kwargs or {} + controlnet_guider_kwargs = { + **controlnet_guider_kwargs, + "disable_guidance": controlnet_disable_guidance, + "guidance_scale": guidance_scale, + "guidance_rescale": guidance_rescale, + "batch_size": batch_size, + } + pipeline.controlnet_guider.set_guider(pipeline, 
+        controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(prompt_embeds)
+        controlnet_added_cond_kwargs = {
+            "text_embeds": pipeline.controlnet_guider.prepare_input(pooled_prompt_embeds),
+            "time_ids": pipeline.controlnet_guider.prepare_input(add_time_ids),
+        }
+        for idx, _ in enumerate(control_image):
+            control_image[idx] = pipeline.controlnet_guider.prepare_input(control_image[idx], control_image[idx])
+
+        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0)
+
+        control_type = (
+            control_type.reshape(1, -1)
+            .to(device, dtype=prompt_embeds.dtype)
+            .repeat(batch_size * num_images_per_prompt * 2, 1)
+        )
+        with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # prepare latents for unet using the guider
+                latent_model_input = pipeline.guider.prepare_input(latents, latents)
+
+                # prepare latents for controlnet using the guider
+                control_model_input = pipeline.controlnet_guider.prepare_input(latents, latents)
+
+                if isinstance(controlnet_keep[i], list):
+                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+                else:
+                    controlnet_cond_scale = controlnet_conditioning_scale
+                    if isinstance(controlnet_cond_scale, list):
+                        controlnet_cond_scale = controlnet_cond_scale[0]
+                    cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+                down_block_res_samples, mid_block_res_sample = pipeline.controlnet(
+                    pipeline.scheduler.scale_model_input(control_model_input, t),
+                    t,
+                    encoder_hidden_states=controlnet_prompt_embeds,
+                    controlnet_cond=control_image,
+                    control_type=control_type,
+                    control_type_idx=control_mode,
+                    conditioning_scale=cond_scale,
+                    guess_mode=guess_mode,
+                    added_cond_kwargs=controlnet_added_cond_kwargs,
+                    return_dict=False,
+                )
+
+                # when we apply guidance for unet, but not for controlnet:
+                # add 0 to the unconditional batch
+                down_block_res_samples = pipeline.guider.prepare_input(
+                    down_block_res_samples, [torch.zeros_like(d) for d in down_block_res_samples]
+                )
+                mid_block_res_sample = pipeline.guider.prepare_input(
+                    mid_block_res_sample, torch.zeros_like(mid_block_res_sample)
+                )
+
+                latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)
+
+                noise_pred = pipeline.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
+                    down_block_additional_residuals=down_block_res_samples,
+                    mid_block_additional_residual=mid_block_res_sample,
+                    return_dict=False,
+                )[0]
+                # perform guidance
+                noise_pred = pipeline.guider.apply_guidance(noise_pred, timestep=t, latents=latents)
+                # compute the previous noisy sample x_t -> x_t-1
+                latents_dtype = latents.dtype
+                latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
+                        # some platforms (e.g. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
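+
+                # inpainting path: blend the denoised latents with the re-noised
+                # original image latents so that only the masked region is updated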
+                if mask is not None and image_latents is not None:
+                    init_mask = pipeline.guider._maybe_split_prepared_input(mask)[0]
+                    init_latents_proper = image_latents
+                    if i < len(timesteps) - 1:
+                        noise_timestep = timesteps[i + 1]
+                        init_latents_proper = pipeline.scheduler.add_noise(
+                            init_latents_proper, noise, torch.tensor([noise_timestep])
+                        )
+
+                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
+                    progress_bar.update()
+
+        pipeline.guider.reset_guider(pipeline)
+        pipeline.controlnet_guider.reset_guider(pipeline)
+        state.add_intermediate("latents", latents)
+
+        return pipeline, state
+
+
 class StableDiffusionXLDecodeLatentsStep(PipelineBlock):
     expected_components = ["vae"]
     model_name = "stable-diffusion-xl"
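
A minimal usage sketch of the new block, for review purposes only. It assumes a
`PipelineState` that can be constructed directly and a hypothetical `add_input`
setter (this diff only shows `get_input`/`get_intermediate`/`add_intermediate`),
and it presumes upstream blocks have already written the denoising intermediates
(latents, embeddings, timesteps) into the state:

    # hypothetical wiring; `add_input` and the values below are illustrative
    state = PipelineState()
    state.add_input("control_image", [pose_image])          # one image per active control
    state.add_input("control_mode", [0])                    # index into the union model's control types
    state.add_input("controlnet_conditioning_scale", 0.8)

    denoise_step = StableDiffusionXLControlNetUnionDenoiseStep()
    # the block pulls everything else from pipeline/state and writes "latents" back
    pipeline, state = denoise_step(pipeline, state)
    latents = state.get_intermediate("latents")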